import dayjs from 'dayjs'
import { flatten, map, range, times } from 'lodash/fp'
import { supportedPlateFormats } from '~/components/TinyMicroplate.interface'
import { DemoSourceRecordsByKind } from '~/demoControls/demoData'
import imageSets from '~/demoControls/imageSets'
import {
  gaussianRandom,
  logisticFun,
  pseudoRandom,
} from '~/pages/slasDemo/events/data/generate/generateDataset'
import { convertWellCoordsToWellName, getPlateDims } from '~/utils/microplate'
import { getPlateAndWellToValues } from './hitpickingData'

// CONFLUENCY GENERATION

/**
 * Produces `numSamples` simulated confluency readings. Every reading is drawn
 * from `generateConfluencyForDataset` at day 2 using growth-curve index 1.
 */
export const generateConfluencyValues = (numSamples: number): number[] => {
  const drawSample = (): number => generateConfluencyForDataset(2, 1)
  const samples = times(drawSample, numSamples)
  return [...samples]
}
/**
 * Parameter sets (`A`, `L`, `k`, `t0`) handed to `logisticFun` to build a growth curve.
 * `generateConfluencyForDataset` picks one entry using its `speed` argument as the
 * array index, so the order of the entries matters to callers.
 * NOTE(review): presumably fitted logistic-curve coefficients — confirm the meaning
 * of each field against `logisticFun`.
 */
export const fittedGrowthCurveConstants = [
  { A: 15.52, L: 40.74, k: 1.2, t0: 2.62 },
  { A: 20.52, L: 70.74, k: 1.38, t0: 2.42 },
  { A: 18.52, L: 82.74, k: 1.38, t0: 2.32 },
  { A: 13.52, L: 90.74, k: 1.48, t0: 2.02 },
  { A: 1.52, L: 40.74, k: 0.9, t0: 4.62 },
  { A: 2.52, L: 76.74, k: 0.8, t0: 4.32 },
]
/**
 * Simulates one confluency reading for a well imaged around `day`.
 *
 * The imaging time is `day` plus a `gaussianRandom` offset, and the value is read
 * off the curve built by `logisticFun` from `fittedGrowthCurveConstants[speed]`,
 * rounded to one decimal place.
 *
 * @param day - Nominal imaging day.
 * @param speed - Index into `fittedGrowthCurveConstants`.
 * @returns The confluency value rounded to one decimal place.
 * @throws Error when `speed` does not select a curve, or when evaluating the curve fails.
 */
export function generateConfluencyForDataset(day: number, speed: number): number {
  // Fail with a precise message instead of letting an undefined curve blow up downstream.
  const curveConstants = fittedGrowthCurveConstants[speed]
  if (curveConstants === undefined) {
    throw new Error(
      `Failed to get confluence for ${day}: no growth curve at speed index ${speed}`,
    )
  }
  try {
    // Try a standard deviation of 300 minutes for lots of noise.
    const stdDevInDays = 300 / 60 / 24
    const imagingTime = day + gaussianRandom(0, stdDevInDays)
    return Number(logisticFun(curveConstants)(imagingTime).toFixed(1))
  } catch (e) {
    // Keep the underlying reason so the failure can actually be diagnosed.
    const reason = e instanceof Error ? e.message : String(e)
    throw new Error(`Failed to get confluence for ${day} (speed ${speed}): ${reason}`)
  }
}

// IMAGE GENERATION

/**
 * Returns the first demo image set whose name contains `nameFragment`.
 * Throws a descriptive error when none matches, rather than deferring the
 * failure to the first property access on an undefined image set.
 */
const requireImageSet = (nameFragment: string): (typeof imageSets)[number] => {
  const match = imageSets.find(value => value.name.includes(nameFragment))
  if (match === undefined) {
    throw new Error(
      `Demo image set not found: no image set name includes "${nameFragment}"`,
    )
  }
  return match
}

// Image set used for the initial plate's early imaging days (see plateImageSet).
const singleColonyImages = requireImageSet('Simulated single-colony')
// Image set used for every other plate/day, indexed by confluence and plate format.
const ipscImages = requireImageSet('iPSC (Science Corp)')

// PLATE GENERATION

/**
 * Merges several record bundles into one by concatenating their `plates`,
 * `wells` and `timepoints` lists in input order. A missing list counts as empty.
 */
const combinePlates = (plates: DemoSourceRecordsByKind[]): DemoSourceRecordsByKind => {
  const allPlates = plates.map(source => source.plates ?? []).flat()
  const allWells = plates.map(source => source.wells ?? []).flat()
  const allTimepoints = plates.map(source => source.timepoints ?? []).flat()
  return { plates: allPlates, wells: allWells, timepoints: allTimepoints }
}

/**
 * Maps a plate format to the offset used when indexing the iPSC image set:
 * 2 for 96-well, 1 for 24-well, 0 for any other format.
 */
const plateFormatToImageSetIndex = (format: supportedPlateFormats): number => {
  switch (format) {
    case 'wells_96':
      return 2
    case 'wells_24':
      return 1
    default:
      return 0
  }
}

// Day 0 through Day 9
// Last image is our >70%
// 0, 1, 2, 3, 7, 9

/**
 * Picks the sprite directory to show for one well timepoint.
 *
 * - For the initial plate (barcode containing 'ACS-1020_LDK1-KO_1') before day 8,
 *   the single-colony image set is indexed by the position of `day` in
 *   [0, 1, 2, 3, 7]; the day-7 image falls back to the day-3 one when
 *   confluence is below 50.
 * - Otherwise the iPSC image set is indexed by a confluence bucket
 *   (floor(confluence / 14)) times 3, plus the plate-format offset.
 *
 * Returns '' when no matching observation exists. `well` is currently unused.
 */
const plateImageSet = (
  barcode: string,
  well: string,
  confluence: number,
  format: supportedPlateFormats,
  day: number,
) => {
  const isEarlyInitialPlate = barcode.includes('ACS-1020_LDK1-KO_1') && day < 8

  if (!isEarlyInitialPlate) {
    const confluenceBucket = Math.floor(confluence / 14)
    const observationIndex = confluenceBucket * 3 + plateFormatToImageSetIndex(format)
    return ipscImages.plateObservations?.[observationIndex]?.spriteDirectory ?? ''
  }

  // indexOf yields -1 for days outside the list, which resolves to '' below.
  const dayIndex = [0, 1, 2, 3, 7].indexOf(day)
  const observationIndex = confluence < 50 && dayIndex === 4 ? 3 : dayIndex
  return singleColonyImages.plateObservations?.[observationIndex]?.spriteDirectory ?? ''
}

/**
 * Builds the demo records for a single plate: one plate record, one record per
 * well, and one timepoint per well per entry in `days`.
 *
 * @param barcode - Plate barcode, copied onto every record.
 * @param format - Plate format; determines the row/column grid of wells.
 * @param seeded - Seeding time; timepoint timestamps are `seeded` + `days[i]` days.
 * @param checkedOut - Check-out time, or null. Only emitted when it lies in the past.
 * @param days - Day offsets (from `seeded`) at which timepoints are recorded.
 * @param imageDays - Per-timepoint day passed to the confluency generator; read by
 *   the same index as `days`.
 * @param sourceWells - `[plateBarcode, wellPosition]` parents, matched to wells by index.
 * @param passageNumber - Passage number written on every well (default 1).
 * @param numWells - Optional cap on the number of wells (row-major order).
 * @param speeds - Growth-curve indices; one is picked per well via `pseudoRandom`.
 */
const generatePlate = (
  barcode: string,
  format: supportedPlateFormats,
  seeded: dayjs.Dayjs,
  checkedOut: dayjs.Dayjs | null,
  days: number[],
  imageDays: number[],
  sourceWells: [string, string][],
  passageNumber: number = 1,
  numWells?: number,
  speeds: number[] = [1, 2, 3],
): DemoSourceRecordsByKind => {
  const { numRows, numCols } = getPlateDims(format)

  // Well names in row-major order, truncated to `numWells` when it is provided.
  const rowIndices = range(0, numRows)
  const colIndices = range(0, numCols)
  const wellNamesByRow = map(
    row => map(col => convertWellCoordsToWellName(row, col), colIndices),
    rowIndices,
  )
  const wells = flatten(wellNamesByRow).slice(0, numWells ?? numRows * numCols)

  // One growth curve is chosen per well and reused for all of its timepoints.
  const timepointsForWell = (well: string) => {
    const speed = speeds[Math.floor(pseudoRandom() * speeds.length)]
    return days.map((day, dayIndex) => {
      const confluence = generateConfluencyForDataset(imageDays[dayIndex], speed)
      return {
        plateBarcode: barcode,
        wellPosition: well,
        timestamp: seeded.add(day, 'days').toISOString(),
        confluence: confluence.toString(),
        plateImageSet: plateImageSet(barcode, well, confluence, format, day),
      }
    })
  }
  const timepoints = wells.flatMap(well => timepointsForWell(well))

  // Times are stored as fractional day offsets relative to "now".
  const seededOffset = seeded.diff(dayjs(), 'days', true).toString()
  let checkedOutOffset = ''
  if (checkedOut !== null && checkedOut.valueOf() < dayjs().valueOf()) {
    checkedOutOffset = checkedOut.diff(dayjs(), 'days', true).toString()
  }

  const wellRecords = wells.map((well, index) => {
    const parent = sourceWells[index]
    return {
      plateBarcode: barcode,
      position: well,
      parentWellPlateBarcode: parent?.[0] ?? '',
      parentWellPosition: parent?.[1] ?? '',
      cellLine: 'ACS-1020-LDK1',
      cellLineLot: '1',
      passageNumber: passageNumber.toString(),
    }
  })

  return {
    plates: [
      {
        barcode,
        format: format.replace('wells_', ''),
        seeded: seededOffset,
        checkedOut: checkedOutOffset,
      },
    ],
    wells: wellRecords,
    timepoints,
  }
}

/**
 * Generates as many plates of `format` as are needed to hold `sourceWells`,
 * and returns their records combined.
 *
 * Plates are named `barcodePrefix` + 'A', 'B', 'C', … and each plate receives
 * its own consecutive slice of `sourceWells` as parent wells, so no parent well
 * is assigned to more than one destination plate.
 *
 * @param sourceWells - `[plateBarcode, wellPosition]` of each parent well.
 * @param barcodePrefix - Prefix to which the plate letter is appended.
 * @param format - Format of the generated plates.
 * @param seeded - Seeding time of the generated plates.
 * @param checkedOut - Check-out time of the generated plates, or null.
 * @param days - Day offsets used both for timestamps and for confluency.
 * @param passageNumber - The generated wells are written with `passageNumber + 1`.
 * @param speeds - Optional growth-curve indices forwarded to `generatePlate`.
 */
const generatePassagedPlate = (
  sourceWells: [string, string][],
  barcodePrefix: string,
  format: supportedPlateFormats,
  seeded: dayjs.Dayjs,
  checkedOut: dayjs.Dayjs | null,
  days: number[],
  passageNumber: number = 1,
  speeds?: number[],
): DemoSourceRecordsByKind => {
  const { numRows, numCols } = getPlateDims(format)
  const numWellsPerPlate = numRows * numCols
  const numPlates = Math.ceil(sourceWells.length / numWellsPerPlate)
  const plates = range(0, numPlates).map(plateNumber => {
    // generatePlate matches parents to wells by index, so hand each plate only
    // the parents that belong to it.
    const firstSourceIndex = plateNumber * numWellsPerPlate
    const plateSourceWells = sourceWells.slice(
      firstSourceIndex,
      firstSourceIndex + numWellsPerPlate,
    )
    return generatePlate(
      `${barcodePrefix}${String.fromCharCode('A'.charCodeAt(0) + plateNumber)}`,
      format,
      seeded,
      checkedOut,
      days,
      days,
      plateSourceWells,
      passageNumber + 1,
      undefined,
      speeds,
    )
  })
  return combinePlates(plates)
}

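// Timeline, in days after the initial plate is seeded (day 0):
//   initial plate (KO_1):      seeded day 0,    checked out day 9.2
//   passaged plates (KO_2A…):  seeded day 9.2,  checked out day 13.2
//   genotyping plate (KO_2D):  seeded day 9.2,  checked out day 10.2
//   final plates (KO_3A…):     seeded day 13.2, checked out day 17.2

// Each generator below takes `day`: the current time is `day` days after the seeding time.
// A generated timestamp is current time - `day` days + the time the event was placed.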

// Reference "now" shared by all generators below.
// Set back three hours to account for time taken to generate the data.
const currentTime = dayjs().subtract(3, 'hour')

/**
 * Records for the initial 96-well plate, as seen `day` days after it was seeded.
 * It is seeded at day 0, checked out at day 9.2, and has timepoints on days
 * 0, 1, 2, 3, 7 and 9.
 */
const initialPlate = (day: number) => {
  const seededAt = currentTime.subtract(day * 24, 'hours')
  const checkedOutAt = currentTime.subtract((day - 9.2) * 24, 'hours')
  const timepointDays = [0, 1, 2, 3, 7, 9]
  // TODO: rejigger these confluences
  const confluenceDays = [0, 1, 2, 3, 7, 9]
  // Growth-curve indices to pick from; index 5 is three times as likely as index 4.
  const growthCurveChoices = [4, 5, 5, 5]

  return generatePlate(
    'ACS-1020_LDK1-KO_1',
    'wells_96',
    seededAt,
    checkedOutAt,
    timepointDays,
    confluenceDays,
    [],
    1,
    undefined,
    growthCurveChoices,
  )
}

/**
 * Selects the wells of the initial plate that get passaged: the 72 wells with
 * the highest confluence at the day-9 timepoint, as `[plateBarcode, wellPosition]`.
 *
 * Note that this generates a fresh initial plate via `initialPlate(day)`.
 */
const initialPassages = (day: number) => {
  // Build the day-9 timestamp exactly the way generatePlate does (seeding time
  // plus 9 days). Deriving it from a fractional 'days' offset only matched by
  // way of rounding and could miss every timepoint across a DST change.
  const dayNineTimestamp = currentTime
    .subtract(day * 24, 'hours')
    .add(9, 'days')
    .toISOString()

  return (initialPlate(day).timepoints ?? [])
    .filter(t => t.timestamp === dayNineTimestamp)
    .sort((a, b) => parseFloat(b.confluence) - parseFloat(a.confluence))
    .slice(0, 72)
    .map(t => [t.plateBarcode, t.wellPosition] as [string, string])
}

/**
 * 24-well plates passaged from the initial plate, with only their seeding-day
 * (day 0) timepoint. Seeded at day 9.2 and checked out at day 13.2.
 */
const startingPassagedPlates = (day: number) => {
  const parentWells = initialPassages(day)
  const seededAt = currentTime.subtract((day - 9.2) * 24, 'hours')
  const checkedOutAt = currentTime.subtract((day - 13.2) * 24, 'hours')

  return generatePassagedPlate(
    parentWells,
    'ACS-1020_LDK1-KO_2',
    'wells_24',
    seededAt,
    checkedOutAt,
    [0],
    2,
  )
}

/**
 * 24-well plates passaged from the initial plate, with timepoints on days
 * 0 through 3 after passaging. Seeded at day 9.2 and checked out at day 13.2.
 */
const fullyPassagedPlates = (day: number) => {
  const parentWells = initialPassages(day)
  const seededAt = currentTime.subtract((day - 9.2) * 24, 'hours')
  const checkedOutAt = currentTime.subtract((day - 13.2) * 24, 'hours')

  return generatePassagedPlate(
    parentWells,
    'ACS-1020_LDK1-KO_2',
    'wells_24',
    seededAt,
    checkedOutAt,
    [0, 1, 2, 3],
    2,
  )
}

/**
 * Plate record for the 96-well genotyping plate, seeded at day 9.2 and checked
 * out at day 10.2. `numWells` is 0, so no well or timepoint records are produced.
 * NOTE(review): this barcode is spelled 'ACS-1020-LDK1…' while the other plates
 * use 'ACS-1020_LDK1…' — confirm whether the difference is intentional.
 */
const genotypingPlate = (day: number) => {
  const seededAt = currentTime.subtract((day - 9.2) * 24, 'hours')
  const checkedOutAt = currentTime.subtract((day - 10.2) * 24, 'hours')
  const wellCount = 0

  return generatePlate(
    'ACS-1020-LDK1-KO_2D',
    'wells_96',
    seededAt,
    checkedOutAt,
    [0],
    [1],
    [],
    2,
    wellCount,
  )
}

// Per plate-and-well values provided by the hit-picking data module.
export const plateAndWellValues = getPlateAndWellToValues()

// The 12 entries with the lowest score, where score = doublingRate * 2 - koScore.
export const selectedHitPicksWells = Object.entries(plateAndWellValues)
  .sort(([, left], [, right]) => {
    const leftScore = left.doublingRate * 2 - left.koScore
    const rightScore = right.doublingRate * 2 - right.koScore
    return leftScore - rightScore
  })
  .slice(0, 12)
// Converts each selected "plate…-well" key into a [plateBarcode, wellPosition] pair.
// The key's first character is shifted so that '1' becomes 'A', '2' becomes 'B', …
// and appended to the passaged-plate barcode prefix; the well position is the
// text after the last '-'.
const selectedHitPicksWellDict = selectedHitPicksWells.map(
  ([plateAndWell]): [string, string] => {
    const plateLetterCode =
      'A'.charCodeAt(0) - '1'.charCodeAt(0) + plateAndWell.charCodeAt(0)
    const plateBarcode = 'ACS-1020_LDK1-KO_2' + String.fromCharCode(plateLetterCode)
    const wellPosition = plateAndWell.substring(plateAndWell.lastIndexOf('-') + 1)
    return [plateBarcode, wellPosition]
  },
)
/**
 * 6-well plates seeded from the selected hit-pick wells at day 13.2 and checked
 * out at day 17.2. Only the first `day - 13` of the timepoint days 0..3 are
 * included, and growth-curve index 3 is always used.
 */
const finalPlates = (day: number) => {
  const seededAt = currentTime.subtract((day - 13.2) * 24, 'hours')
  const checkedOutAt = currentTime.subtract((day - 17.2) * 24, 'hours')
  const elapsedTimepointDays = [0, 1, 2, 3].slice(0, day - 13)
  const growthCurveChoices = [3]

  return generatePassagedPlate(
    selectedHitPicksWellDict,
    'ACS-1020_LDK1-KO_3',
    'wells_6',
    seededAt,
    checkedOutAt,
    elapsedTimepointDays,
    3,
    growthCurveChoices,
  )
}

// Day 9: the initial plate only.
export const initialDemoData = initialPlate(9)

// Day 10: initial plate, passaged plates with only their seeding-day timepoint,
// and the genotyping plate.
export const postClonePickingDemoData = combinePlates([
  initialPlate(10),
  startingPassagedPlates(10),
  genotypingPlate(10),
])

// Day 13: initial plate plus passaged plates with all of their timepoints.
export const preHitpickingData = combinePlates([
  initialPlate(13),
  fullyPassagedPlates(13),
])

// Day 14: as above, plus the final plates with their first timepoint.
export const postHitpickingData = combinePlates([
  initialPlate(14),
  fullyPassagedPlates(14),
  finalPlates(14),
])

// Day 17: as above, with the final plates carrying all four timepoints.
export const finalDemoData = combinePlates([
  initialPlate(17),
  fullyPassagedPlates(17),
  finalPlates(17),
])
