import * as THREE from 'three'

import { assertUnitLength } from './geometry'

// TODO: Revise function so that rather than using ray intersection to find
// the point on the mesh, it searches for the closest point to the nudged
// point from points on a planar cross section of the mesh - the plane defined
// by the combination of the look and the nudge vectors.

/**
 * Nudges a point that lies on a mesh and projects the nudged point back onto
 * the mesh by casting a ray along the look direction.
 *
 * Coordinate systems: `originalPoint` and the returned point are in the
 * geometry's (local) coordinate system, as defined by the world transform on
 * `mesh` (`mesh.matrixWorld`). The nudging itself is computed in world space;
 * `nudgeDirection`, `nudgeDistance`, `maxSteepness`, `lookVector` and
 * `reprojectDistance` are all world-space quantities.
 * NOTE(review): assumes `mesh.matrixWorld` is up to date — confirm callers
 * have updated the world matrix before calling.
 *
 * @param raycaster - Raycaster used for the reprojection; its ray is
 *   overwritten via `raycaster.set`.
 * @param originalPoint - Starting point in mesh-local coordinates. Not mutated.
 * @param nudgeDirection - Unit-length world-space direction to nudge in
 *   (checked with `assertUnitLength`).
 * @param nudgeDistance - World-space distance to move along `nudgeDirection`.
 * @param maxSteepness - The result is rejected when the reprojected point is
 *   at least `maxSteepness * nudgeDistance` away from the nudged point.
 *   Defaults to 5.
 * @param lookVector - Unit-length world-space direction the ray is cast along
 *   (checked with `assertUnitLength`).
 * @param mesh - Mesh to reproject onto.
 * @param reprojectDistance - How far, in world units, the ray origin is backed
 *   off from the nudged point along `-lookVector`. It should be large enough
 *   that the origin lies outside the mesh. Defaults to 1000.
 * @returns The reprojected point in mesh-local coordinates, or `undefined` if
 *   the ray misses the mesh or the reprojected point is too far from the
 *   nudged point.
 */
export function nudgeAndReproject({
  raycaster,
  originalPoint,
  nudgeDirection,
  nudgeDistance,
  maxSteepness = 5,
  lookVector,
  mesh,
  reprojectDistance = 1000,
}: {
  raycaster: THREE.Raycaster
  originalPoint: THREE.Vector3
  nudgeDirection: THREE.Vector3
  nudgeDistance: number
  maxSteepness?: number
  lookVector: THREE.Vector3
  mesh: THREE.Mesh
  reprojectDistance?: number
}): THREE.Vector3 | undefined {
  assertUnitLength(nudgeDirection)
  assertUnitLength(lookVector)

  // Every distance comparison below happens in world space, so lift the
  // local-space input point into world space once.
  const originalWorldPoint = originalPoint
    .clone()
    .applyMatrix4(mesh.matrixWorld)
  const nudgedPoint = originalWorldPoint
    .clone()
    .addScaledVector(nudgeDirection, nudgeDistance)

  // Back the ray origin off towards the camera so it isn't inside the mesh,
  // then cast forward along the look direction.
  const rayOrigin = nudgedPoint
    .clone()
    .addScaledVector(lookVector, -reprojectDistance)
  raycaster.set(rayOrigin, lookVector)
  const intersections = raycaster.intersectObjects([mesh])
  if (intersections.length === 0) {
    return undefined
  }

  // The ray may cross the mesh several times; keep the hit closest to the
  // original point. Intersection points are in world space, so they must be
  // compared against the world-space original point, not the local one.
  // On ties the earliest hit wins.
  const closest = intersections.reduce((best, hit) =>
    hit.point.distanceTo(originalWorldPoint) <
    best.point.distanceTo(originalWorldPoint)
      ? hit
      : best
  )
  const chosenPoint = closest.point

  // Reject reprojections that jump too far relative to the nudge size.
  if (nudgedPoint.distanceTo(chosenPoint) >= maxSteepness * nudgeDistance) {
    return undefined
  }

  // Convert the world-space hit back into the geometry's coordinate system.
  // The intersection object is freshly created by the raycaster, so mutating
  // its point in place is safe.
  const worldToLocal = mesh.matrixWorld.clone().invert()
  return chosenPoint.applyMatrix4(worldToLocal)
}
