import { PartialMessage, PlainMessage } from '@bufbuild/protobuf'
import { useMutation, useQueryClient } from '@tanstack/react-query'

import { useGetCompany, useListInherentRiskCategories } from '@/api/company.hook'
import {
  getCompany,
  getCompanyActivityLog,
  getInherentRisk,
  listCompanies,
  listInherentRiskCategories,
  setIRQItem,
} from '@/gen/inventory/v1/company_service-CompanyInventoryService_connectquery'
import {
  GetInherentRiskResponse,
  IRQChangeType,
  InherentRiskCategoryEnum,
  InherentRiskItem,
  InherentRiskSource,
  SetIRQItemRequest,
} from '@/gen/inventory/v1/company_service_pb'

import { InherentRiskCategoryEnumName, changeTypeLabel } from '@/const/label'

import { useTrackCallback } from '@/lib/analytics/events'

/**
 * Mutation hook that adds or removes a single IRQ item for a company and
 * optimistically applies that change to the cached `getInherentRisk` response.
 *
 * Lifecycle:
 * - onMutate: cancels in-flight `getInherentRisk` fetches, snapshots the cached
 *   response, fires the `third-party.irq.set` analytics event, and writes the
 *   optimistic result produced by `updateInherentRiskData`.
 * - onError: restores the snapshot (when one existed).
 * - onSettled: invalidates the activity log, category list, inherent-risk,
 *   company and company-list queries so server state replaces the optimistic one.
 *
 * @param companyId - Id of the company whose IRQ item is being changed.
 */
export const useOptimisticSetIRQItem = (companyId: string) => {
  const { data: riskCategoryDefinitions } = useListInherentRiskCategories(companyId)
  const { data: company } = useGetCompany(companyId)
  const trackIRQSet = useTrackCallback('third-party.irq.set')
  const queryClient = useQueryClient()

  const mutation = setIRQItem.useMutation()
  // Every optimistic-cache operation below targets this company's inherent-risk query.
  const inherentRiskQueryKey = getInherentRisk.getQueryKey({ companyId })

  return useMutation(mutation.mutationFn, {
    ...mutation,
    onMutate: async (variables) => {
      // Cancel outgoing refetches so they cannot overwrite the optimistic update.
      await queryClient.cancelQueries(inherentRiskQueryKey)

      // Snapshot the current cache entry so onError can roll back to it.
      const previousData =
        queryClient.getQueryData<PlainMessage<GetInherentRiskResponse>>(inherentRiskQueryKey)

      // The category definitions may not have loaded yet. Tolerate that instead of
      // throwing here, which would abort the mutation before it is ever sent.
      const categoryDefinition = riskCategoryDefinitions?.find(
        (catGroup) => catGroup.categoryEnum === variables.category,
      )

      const itemDefinition = categoryDefinition?.items.find(
        (item) => item.id === variables.riskCategoryId,
      )

      // Fired from onMutate, i.e. before the server has confirmed the change.
      trackIRQSet({
        companyId,
        companyName: company?.company?.profile?.name,
        changeType:
          changeTypeLabel[variables.changeType ?? IRQChangeType.IRQ_CHANGE_TYPE_UNSPECIFIED],
        displayName: itemDefinition?.displayName,
        category:
          InherentRiskCategoryEnumName[
            categoryDefinition?.categoryEnum ?? InherentRiskCategoryEnum.UNSPECIFIED
          ],
      })

      queryClient.setQueryData<PlainMessage<GetInherentRiskResponse> | undefined>(
        inherentRiskQueryKey,
        (oldData) => updateInherentRiskData(oldData, variables, itemDefinition),
      )

      // Returned as the mutation context so onError can restore it.
      return { previousData }
    },
    onError: (_error, _variables, context) => {
      // Roll back to the snapshot; when nothing was cached there is nothing to restore.
      if (context?.previousData) {
        queryClient.setQueryData(inherentRiskQueryKey, context.previousData)
      }
    },
    onSettled: () => {
      // Refresh everything derived from the company's IRQ answers, whether or not
      // the mutation succeeded, so the optimistic state is replaced by server data.
      queryClient.invalidateQueries(getCompanyActivityLog.getQueryKey({ companyId }))
      queryClient.invalidateQueries(listInherentRiskCategories.getQueryKey({ companyId }))
      queryClient.invalidateQueries(inherentRiskQueryKey)
      queryClient.invalidateQueries(getCompany.getQueryKey({ id: companyId }))
      queryClient.invalidateQueries(listCompanies.getQueryKey())
    },
  })
}

/**
 * Produces the optimistic version of a cached `GetInherentRiskResponse` after a
 * `SetIRQItemRequest`, without mutating the input.
 *
 * Only groups whose `categoryEnum` equals `variables.category` are touched:
 * - REMOVE: drops the IRQ-sourced item whose id equals `variables.riskCategoryId`.
 * - ADD: appends `itemDefinition` re-tagged with `InherentRiskSource.IRQ`, unless an
 *   IRQ-sourced item with that id is already present or no definition was given.
 * - Any other change type leaves the group as-is.
 *
 * The returned response object and its `inherentRiskGroups` array are always new;
 * groups that are not modified are reused by reference.
 *
 * @param oldData - Current cache entry, or `undefined` when nothing is cached.
 * @param variables - The request being sent to the server.
 * @param itemDefinition - Definition of the item being added (needed for ADD).
 * @returns The updated response, or `undefined` when `oldData` was absent.
 */
const updateInherentRiskData = (
  oldData: PlainMessage<GetInherentRiskResponse> | undefined,
  variables: PartialMessage<SetIRQItemRequest>,
  itemDefinition?: PlainMessage<InherentRiskItem>,
): PlainMessage<GetInherentRiskResponse> | undefined => {
  // No cached response yet: leave the cache entry untouched.
  if (!oldData) {
    return oldData
  }

  type RiskGroup = PlainMessage<GetInherentRiskResponse>['inherentRiskGroups'][number]
  type RiskItem = RiskGroup['inherentRiskItems'][number]

  const targetId = variables.riskCategoryId
  const targetCategory = variables.category

  // Matches only the IRQ-sourced entry for the item being changed; items with the
  // same id from other sources are left alone.
  const isTargetIrqItem = (entry: RiskItem) =>
    entry.id === targetId && entry.source === InherentRiskSource.IRQ

  const applyChange = (group: RiskGroup): RiskGroup => {
    switch (variables.changeType) {
      case IRQChangeType.IRQ_CHANGE_TYPE_REMOVE:
        return {
          ...group,
          inherentRiskItems: group.inherentRiskItems.filter((entry) => !isTargetIrqItem(entry)),
        }

      case IRQChangeType.IRQ_CHANGE_TYPE_ADD: {
        // Nothing to add without a definition, and never add the same IRQ item twice.
        if (!itemDefinition || group.inherentRiskItems.some(isTargetIrqItem)) {
          return group
        }
        const addedItem: PlainMessage<InherentRiskItem> = {
          ...itemDefinition,
          source: InherentRiskSource.IRQ,
        }
        return {
          ...group,
          inherentRiskItems: [...group.inherentRiskItems, addedItem],
        }
      }

      default:
        return group
    }
  }

  return {
    ...oldData,
    inherentRiskGroups: oldData.inherentRiskGroups.map((group) =>
      group.categoryEnum === targetCategory ? applyChange(group) : group,
    ),
  }
}
