import * as tfconv from '@tensorflow/tfjs-converter';
import * as tf from '@tensorflow/tfjs-core';
import { MobileNet, PoseNet } from '@tensorflow-models/posenet';

/**
 * Base URL of the locally served MobileNet PoseNet checkpoints.
 * It is built from `window.location.origin` when the module is evaluated, so it points
 * at the same origin as the page running this script.
 * NOTE(review): assumes the build/deploy step serves the checkpoint files under
 * `/tfjs-models/posenet/mobilenet/` on that origin — confirm.
 */
const MOBILENET_BASE_URL = `${window.location.origin}/tfjs-models/posenet/mobilenet/`;

/**
 * Builds the URL of a bundled MobileNet PoseNet graph for one model configuration.
 *
 * The URL has the form `{base}{precision}/{multiplier}/model-stride{stride}.json`:
 * - `precision` is `float` when quantBytes === 4 (the non-quantized, full-precision
 *   checkpoints) and `quant{quantBytes}` otherwise.
 * - `multiplier` is `100`, `075` or `050`.
 *
 * A multiplier other than 1.0 / 0.75 / 0.5 yields an `undefined` path segment.
 *
 * @param {number} stride - Output stride, embedded in the `model-stride{stride}.json` file name.
 * @param {number} multiplier - Depth multiplier (1.0, 0.75 or 0.5).
 * @param {number} quantBytes - Quantization setting; 4 selects the non-quantized checkpoints.
 * @returns {string} The local model URL.
 */
const createLocalModelURL = (stride, multiplier, quantBytes) => {
  const multiplierFolders = { 1.0: '100', 0.75: '075', 0.50: '050' };
  const precisionFolder = quantBytes === 4 ? 'float' : `quant${quantBytes}`;
  const fileName = `model-stride${stride}.json`;

  return `${MOBILENET_BASE_URL}${precisionFolder}/${multiplierFolders[multiplier]}/${fileName}`;
};

/**
 * Checks that an input resolution is either a number or a non-null object with
 * numeric `width` and `height` properties.
 * @param {number|{width: number, height: number}} inputResolution
 * @throws {Error} Via `tf.util.assert` when the value does not have one of those shapes.
 */
const validateInputResolution = (inputResolution) => {
  // `typeof null === 'object'`, so null has to be rejected explicitly. Otherwise the
  // width/height checks below would crash with an opaque TypeError instead of this message.
  const isResolutionObject = typeof inputResolution === 'object' && inputResolution !== null;

  tf.util.assert(
    typeof inputResolution === 'number' || isResolutionObject,
    () => `Invalid inputResolution ${inputResolution}. Should be a number or an object with width and height`,
  );

  if (isResolutionObject) {
    tf.util.assert(
      typeof inputResolution.width === 'number',
      () => `inputResolution.width has a value of ${inputResolution.width} which is invalid; it must be a number`,
    );
    tf.util.assert(
      typeof inputResolution.height === 'number',
      () => `inputResolution.height has a value of ${inputResolution.height} which is invalid; it must be a number`,
    );
  }
};

/**
 * Tells whether a resolution needs no adjustment for the given output stride,
 * i.e. whether `resolution - 1` divides evenly by `outputStride`.
 * @param {number} resolution - Candidate input size.
 * @param {number} outputStride - Output stride of the network.
 * @returns {boolean} `true` when the remainder of `(resolution - 1) / outputStride` is zero.
 */
const isValidInputResolution = (resolution, outputStride) => {
  const remainder = (resolution - 1) % outputStride;
  return remainder === 0;
};

/**
 * Adjusts an input size so that `isValidInputResolution` accepts it.
 * Sizes that already pass are returned unchanged. Any other size becomes
 * `floor(inputResolution / outputStride) * outputStride + 1`.
 * @param {number} inputResolution - Requested input size.
 * @param {number} outputStride - Output stride of the network.
 * @returns {number} The adjusted input size.
 */
const toValidInputResolution = (inputResolution, outputStride) => (
  isValidInputResolution(inputResolution, outputStride)
    ? inputResolution
    : Math.floor(inputResolution / outputStride) * outputStride + 1
);

/**
 * Validates a requested input resolution and adjusts each dimension to a valid size
 * for the given output stride.
 * A plain number is used for both dimensions.
 * @param {number|{width: number, height: number}} inputResolution - Requested input size.
 * @param {number} outputStride - Output stride of the network.
 * @returns {number[]} A two-element array `[height, width]`.
 */
const getValidInputResolutionDimensions = (inputResolution, outputStride) => {
  validateInputResolution(inputResolution);

  const { height, width } = typeof inputResolution === 'object'
    ? inputResolution
    : { height: inputResolution, width: inputResolution };

  return [height, width].map((dimension) => toValidInputResolution(dimension, outputStride));
};

/**
 * Loads a MobileNetV1-based PoseNet from the locally served checkpoints, or from
 * `modelUrl` when one is given.
 *
 * The input resolution is validated and adjusted before any network request, so an
 * invalid configuration fails immediately instead of after the model has been fetched.
 *
 * @param {Object} config - Model configuration.
 * @param {number} config.outputStride - Output stride; selects `model-stride{N}.json` and is passed to MobileNet.
 * @param {number} config.quantBytes - 4 selects the float checkpoints, other values the `quant{N}` folder.
 * @param {number} config.multiplier - Depth multiplier (1.0, 0.75 or 0.5) selecting the local folder.
 * @param {string} [config.modelUrl] - Explicit model URL; when truthy it replaces the local URL.
 * @param {number|{width: number, height: number}} config.inputResolution - Requested input size.
 * @returns {Promise<PoseNet>} The loaded PoseNet model.
 */
const loadMobileNet = async ({
  outputStride,
  quantBytes,
  multiplier,
  modelUrl,
  inputResolution,
}) => {
  // Validation and adjustment only, no I/O: run it before fetching so bad input fails fast.
  const validInputResolution = getValidInputResolutionDimensions(inputResolution, outputStride);

  // `||` (not `??`) keeps treating an empty-string modelUrl as "not provided".
  // The local URL is only built when it is actually going to be used.
  const url = modelUrl || createLocalModelURL(outputStride, multiplier, quantBytes);
  const graphModel = await tfconv.loadGraphModel(url);
  const mobilenet = new MobileNet(graphModel, outputStride);

  return new PoseNet(mobilenet, validInputResolution);
};

export default loadMobileNet;
