#!/usr/bin/python3


r'''Converts a camera model from one lens model to another

SYNOPSIS

  $ mrcal-convert-lensmodel
      --viz LENSMODEL_OPENCV4 left.cameramodel
      > left.opencv4.cameramodel

  ... lots of output as the solve runs ...
  RMS error of the solution: 3.40256580058 pixels.

  ... a plot pops up showing the differences ...

Given a camera model, this tool computes another model that represents the same
lens, but using a different lens model. While lens models all exist to solve the
same problem, the different representations don't map to one another perfectly,
and this tool seeks to find the best-fitting parameters of the target lens
model. Two different methods are implemented:

1. If the given cameramodel file contains optimization_inputs, then we have all
   the data that was used to compute this model in the first place, and we can
   re-run the original optimization, using the new lens model. This is the
   default behavior. If the input model doesn't have optimization_inputs, an
   error will result, and the other method must be selected by passing --sampled

2. We can sample lots of points on the imager, unproject them to observation
   vectors in the camera coordinate system, and then fit a new camera model that
   reprojects these vectors as closely to the original pixel coordinates as
   possible. Select this mode by passing --sampled.

The first method is preferred. Since camera models (lens parameters AND
geometry) are computed off real pixel observations, the confidence of the final
projections varies greatly, depending on the location of the points being
projected. The first method uses the original data, so it implicitly respects
these uncertainties 100%: low-data areas in the original model will also be
low-data areas in the new model. The second method, however, doesn't have this
information: it doesn't know which parts of the imager are reliable and which
aren't, so the results won't be as good.

As always, the intrinsics have some baked-in geometry information. Both methods
optimize intrinsics AND extrinsics, and output cameramodels with updated
versions of both. If --sampled, then we can request that only the intrinsics be
optimized by passing --intrinsics-only. Also, if --sampled then we fit the
extrinsics off 3D points, not just observation directions. The distance from the
camera to the fitting points is set by --distance. Set this to the distance
where you expect the intrinsics to have the most accuracy. This is only needed
if --sampled and not --intrinsics-only.

If --sampled, we need to consider that the model we're trying to fit may not fit
the original model in all parts of the imager. Usually this is a factor when
converting wide-angle cameramodels to use a leaner model: a decent fit will be
possible at the center, with more and more divergence as we move towards the
edges. We handle this with the --where and --radius options to allow the user to
choose the area of the imager that is used for the fit. This region is centered
on the point given by --where (or at the center of the imager, if omitted). The
radius of this region is given by --radius. If '--radius 0' then I use ALL the
data. A radius<0 can be used to set the size of the no-data margin at the
corners; in this case I'll use

    r = sqrt(width^2 + height^2)/2. - abs(radius)

There's a balance to strike here. A larger radius means that we'll try to fit as
well as we can in a larger area. This might mean that we won't fit well
anywhere, but we won't do terribly anywhere, either. A smaller area means that
we give up on the outer regions entirely (resulting in very poor fits there),
but we'll be able to fit much better in the areas that remain. Generally
empirical testing is required to find a good compromise: pass --viz to see the
resulting differences. Note that --radius and --where apply only if we're
optimizing sampled reprojections; if we're using the original optimization
inputs, these options are illegal.

'''



import sys
import argparse
import re
import os

def parse_args():
    r'''Parse the commandline of this tool

    Builds the argparse parser and parses sys.argv. The module docstring is
    used verbatim as the --help description (hence the
    RawDescriptionHelpFormatter).

    Returns the argparse.Namespace produced by the parser. Its attributes are:
    sampled, gridn, distance, intrinsics_only, where, radius, viz, num_trials,
    to, model

    argparse exits the process (SystemExit) if the commandline is invalid or if
    --help was requested

    '''

    parser = \
        argparse.ArgumentParser(description = __doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('--sampled',
                        action='store_true',
                        help='''Instead of solving the original calibration problem using the new lens model,
                        use sampled imager points. This produces biased results,
                        but can be used even if the original optimization_inputs
                        aren't available''')
    parser.add_argument('--gridn',
                        type=int,
                        default = (30,20),
                        nargs = 2,
                        help='''Used if --sampled. How densely we should sample the imager. By default we use
                        a 30x20 grid''')
    parser.add_argument('--distance',
                        type=float,
                        help='''Used (and required) if --sampled and not
                        --intrinsics-only. A sampled solve fits the intrinsics
                        and extrinsics to match up reprojections of a grid of
                        observed pixels. The points being projected are set a
                        particular distance (set by this argument) from the
                        camera. Set this to the distance that is expected to be
                        most confident for the given cameramodel. Points at
                        infinity aren't explicitly supported: specify a high
                        distance instead''')
    parser.add_argument('--intrinsics-only',
                        action='store_true',
                        help='''Used if --sampled. By default I optimize the
                        intrinsics and extrinsics to find the closest
                        reprojection. If for whatever reason we know that the
                        camera coordinate system was already right, or we need
                        to keep the original extrinsics, pass --intrinsics-only.
                        The resulting extrinsics will be the same, but the fit
                        will not be as good. In many cases, optimizing
                        extrinsics is required to get a usable fit, so
                        --intrinsics-only may not be an option if accurate
                        results are required.''')

    parser.add_argument('--where',
                        type=float,
                        nargs=2,
                        help='''Used if --sampled. I use a subset of the imager to compute the fit. The
                        active region is a circle centered on this point. If
                        omitted, we will focus on the center of the imager''')
    parser.add_argument('--radius',
                        type=float,
                        help='''Used if --sampled. I use a subset of the imager to compute the fit. The
                        active region is a circle with a radius given by this
                        parameter. If radius == 0, I'll use the whole imager for
                        the fit. If radius < 0, this parameter specifies the
                        width of the region at the corners that I should ignore:
                        I will use sqrt(width^2 + height^2)/2. - abs(radius).
                        This is valid ONLY if we're focusing at the center of
                        the imager. By default I ignore a large-ish chunk area
                        at the corners.''')

    parser.add_argument('--viz',
                        action='store_true',
                        help='''Visualize the differences between the input and output models''')

    parser.add_argument('--num-trials',
                        type    = int,
                        default = 1,
                        help='''If given, run the solve more than once. Useful in case random initialization
                        produces noticeably different results. By default we run
                        just one trial, which hopefully should be enough''')

    parser.add_argument('to',
                        type=str,
                        help='The target lens model')

    # The quotes around the "-" were mismatched in this help string before
    parser.add_argument('model',
                        type=str,
                        help='''Input camera model. If "-" is given, we read standard input''')

    return parser.parse_args()

args = parse_args()

# The arguments are parsed BEFORE the imports below (numpy, numpysane, mrcal), so
# that --help works without building stuff. This is what lets me generate the
# manpages and README



import numpy as np
import numpysane as nps
import time
import copy

import mrcal

# The lens model we're converting TO. mrcal must know about it, and it must have
# its gradients implemented
lensmodel_to = args.to

try:
    meta = mrcal.lensmodel_metadata_and_config(lensmodel_to)
except Exception:
    # I catch Exception and not everything: a bare "except:" would also swallow
    # KeyboardInterrupt and SystemExit, and would misreport those as a bad lens
    # model
    print(f"Invalid lens model '{lensmodel_to}': couldn't get the metadata",
          file=sys.stderr)
    sys.exit(1)
if not meta['has_gradients']:
    print(f"lens model {lensmodel_to} is not supported at this time: its gradients aren't implemented",
          file=sys.stderr)
    sys.exit(1)

try:
    # The 4 values subtracted here are the ones seeded separately later on, from
    # the first 4 intrinsics values of the input model. The rest are seeded
    # randomly
    Ndistortions = mrcal.lensmodel_num_params(lensmodel_to) - 4
except Exception:
    print(f"Unknown lens model: '{lensmodel_to}'", file=sys.stderr)
    sys.exit(1)


# Load the input model, and pull out its intrinsics. m.intrinsics() is indexed
# as (lens model name, intrinsics values) below
m               = mrcal.cameramodel(args.model)
intrinsics_from = m.intrinsics()
lensmodel_from  = intrinsics_from[0]

if lensmodel_from == lensmodel_to:
    # Nothing to convert. I report a perfect fit, echo the input model, and
    # I'm done
    for message in ("Input and output have the same lens model: {}. Returning the input\n".format(lensmodel_to),
                    "RMS error of the solution: 0 pixels.\n"):
        sys.stderr.write(message)
    m.write(sys.stdout)
    sys.exit(0)


# The solve. Both branches produce
#
# - m_to:         the output model
# - implied_Rt10: the transformation implied by the change of lens model. This
#                 is composed with the input extrinsics after this block
#
# and both leave a usable reference distance in args.distance, for --viz
if not args.sampled:

    # Default path: re-run the original calibration problem stored in the input
    # model, but using the new lens model. The options that apply only to
    # sampled solves are rejected
    if args.where is not None or args.radius is not None:
        print("--where and --radius only make sense with --sampled", file=sys.stderr)
        sys.exit(1)
    if args.distance is not None:
        raise Exception("--distance requires --sampled")
    if args.intrinsics_only:
        raise Exception("--intrinsics-only requires --sampled")

    # The seed below is (core of the FROM model, Ndistortions values of the TO
    # model), and Ndistortions was computed assuming a 4-value core. So BOTH
    # models need a core. Checking only one of them would let a core-less model
    # through, producing a mis-sized seed
    if not mrcal.lensmodel_metadata_and_config(lensmodel_from)['has_core'] or \
       not mrcal.lensmodel_metadata_and_config(lensmodel_to  )['has_core']:
        raise Exception("Without --sampled, both the FROM and TO models must contain an intrinsics core. It COULD work otherwise, but somebody needs to implement it")

    optimization_inputs = m.optimization_inputs()
    if optimization_inputs is None:
        raise Exception("optimization_inputs not available in this model, so only sampled fits are possible. Pass --sampled")


    # Seed: the core of the input model, and small random distortions
    intrinsics_core = intrinsics_from[1][:4]
    distortions     = (np.random.rand(Ndistortions) - 0.5) * 1e-3
    intrinsics_to_values = nps.dummy(nps.glue(intrinsics_core, distortions, axis=-1),
                                     axis=-2)

    optimization_inputs['lensmodel']  = lensmodel_to
    optimization_inputs['intrinsics'] = intrinsics_to_values
    optimization_inputs['verbose']    = False

    if re.match("LENSMODEL_SPLINED_STEREOGRAPHIC_", lensmodel_to):
        # splined models have a core, but those variables are largely redundant
        # with the spline parameters. So I lock down the core when targetting
        # splined models
        optimization_inputs['do_optimize_intrinsics_core'] = False

    # Snapshot of the pre-solve state. I need it below to compare the geometry
    # before and after the solve
    optimization_inputs_before = copy.deepcopy(optimization_inputs)

    stats = mrcal.optimize(**optimization_inputs)


    # I pick the inlier set using the post-solve inliers/outliers. The solve
    # could add some outliers, but it won't take any away, so the union of the
    # before/after inlier sets is just the post-solve set
    idx_inliers1 = optimization_inputs['observations_board'][...,2] > 0

    calobject0 = \
        mrcal.hypothesis_board_corner_positions(m.icam_intrinsics(),
                                                **optimization_inputs_before,
                                                idx_inliers = idx_inliers1 )[-2]
    calobject1 = \
        mrcal.hypothesis_board_corner_positions(m.icam_intrinsics(),
                                                **optimization_inputs,
                                                idx_inliers = idx_inliers1 )[-2]

    # The transformation that aligns the pre-solve points with the post-solve
    # points
    implied_Rt10 = mrcal.align_procrustes_points_Rt01(calobject1, calobject0)

    # And for the purposes of visualization, it'd be nice to have a reference
    # distance. I use the mean observation distance
    args.distance = np.mean(nps.mag(calobject1))

    sys.stderr.write("RMS error of the solution: {} pixels.\n". \
                     format(stats['rms_reproj_error__pixels']))
    m_to = mrcal.cameramodel( optimization_inputs = optimization_inputs,
                              icam_intrinsics     = m.icam_intrinsics() )

else:

    # Sampled solve. I grid the imager, unproject and fit
    if args.distance is None:
        if args.intrinsics_only:
            # We're only looking at the intrinsics, so the distance doesn't
            # matter. I pick an arbitrary value
            args.distance = 10.
        else:
            raise Exception("--sampled without --intrinsics-only REQUIRES --distance")

    dims = m.imagersize()
    if dims is None:
        sys.stderr.write("Warning: imager size not available. Using centerpixel*2\n")
        dims = intrinsics_from[1][2:4] * 2

    if args.radius is None:
        # By default use 1/4 of the smallest dimension. I use dims here, and NOT
        # m.imagersize(): the latter may be None, which is what the fallback
        # above is for
        args.radius = -np.min(dims) // 4
        sys.stderr.write("Default radius: {}. We're ignoring the regions {} pixels from each corner\n". \
                         format(args.radius, -args.radius))
        if args.where is not None and \
           nps.norm2(args.where - (dims - 1.) / 2) > 1e-3:
            sys.stderr.write("A radius <0 is only implemented if we're focusing on the imager center: use an explicit --radius, or omit --where\n")
            sys.exit(1)


    # Alrighty. Let's actually do the work. I do this:
    #
    # 1. Sample the imager space with the known model
    # 2. Unproject to get the 3d observation vectors
    # 3. Solve a new model that fits those vectors to the known observations, but
    #    using the new model

    ### I sample the pixels in an Nx-by-Ny grid
    Nx,Ny = args.gridn

    qx = np.linspace(0, dims[0]-1, Nx)
    qy = np.linspace(0, dims[1]-1, Ny)

    # q is (Ny*Nx, 2). Each slice of q[:] is an (x,y) pixel coord
    q = np.ascontiguousarray( nps.transpose(nps.clump( nps.cat(*np.meshgrid(qx,qy)), n=-2)) )
    if args.radius != 0:
        # we use a subset of the input data for the fit
        if args.where is None:
            focus_center = (dims - 1.) / 2.
        else:
            focus_center = args.where

        if args.radius > 0:
            r = args.radius
        else:
            # radius<0 gives the size of the no-data margin at the corners:
            # r = sqrt(width^2 + height^2)/2. - abs(radius)
            if nps.norm2(focus_center - (dims - 1.) / 2) > 1e-3:
                sys.stderr.write("A radius <0 is only implemented if we're focusing on the imager center\n")
                sys.exit(1)
            r = nps.mag(dims)/2. + args.radius

        # Keep only the grid points inside the circle. I compare squared
        # distances
        grid_off_center = q - focus_center
        idx_in_region   = nps.norm2(grid_off_center) < r*r
        q = q[idx_in_region, ...]


    # To visualize the sample grid:
    # import gnuplotlib as gp
    # gp.plot(q[:,0], q[:,1], _with='points pt 7 ps 2', xrange=[0,3904],yrange=[3904,0], wait=1, square=1)
    # sys.exit()

    ### I unproject this, with broadcasting. The unit observation vectors are
    ### scaled to sit args.distance away from the camera
    v = mrcal.unproject( q, *intrinsics_from, normalize = True ) * \
        args.distance

    # Ignore any failed unprojections
    i_finite = np.isfinite(v[:,0])
    v = v[i_finite]
    q = q[i_finite]
    Npoints = len(q)
    weights = np.ones((Npoints,), dtype=float)

    ### Solve!

    ### I solve the optimization a number of times with different random seed
    ### values, taking the best-fitting results. This is required for the richer
    ### models such as LENSMODEL_OPENCV8
    err_rms_best               = 1e10
    intrinsics_data_best       = None
    extrinsics_rt_fromref_best = None

    for itrial in range(args.num_trials):
        # random seed for the new intrinsics
        intrinsics_core = intrinsics_from[1][:4]
        distortions     = (np.random.rand(Ndistortions) - 0.5) * 1e-3 # random initial seed
        intrinsics_to_values = nps.dummy(nps.glue(intrinsics_core, distortions, axis=-1),
                                         axis=-2)
        # each point has weight 1.0
        observations_points = nps.glue(q, nps.transpose(weights), axis=-1)
        observations_points = np.ascontiguousarray(observations_points) # must be contiguous. mrcal.optimize() should really be more lax here

        # Which points we're observing. This is dense and kinda silly for this
        # application. Each slice is (i_point,i_camintrinsics,i_camextrinsics).
        # This is rebuilt in every trial because the refinement step below
        # modifies the i_camextrinsics column in-place
        indices_point_camintrinsics_camextrinsics = np.zeros((Npoints,3), dtype=np.int32)
        indices_point_camintrinsics_camextrinsics[:,0] = \
            np.arange(Npoints,    dtype=np.int32)
        indices_point_camintrinsics_camextrinsics[:,1] = \
            np.zeros ((Npoints,), dtype=np.int32)
        # i_camextrinsics = -1 in the first pass: no extrinsics are given
        indices_point_camintrinsics_camextrinsics[:,2] = \
            np.zeros ((Npoints,), dtype=np.int32) - 1

        optimization_inputs = \
            dict(intrinsics                                = intrinsics_to_values,
                 extrinsics_rt_fromref                     = None,
                 frames_rt_toref                           = None, # no frames. Just points
                 points                                    = v,
                 observations_board                        = None, # no board observations
                 indices_frame_camintrinsics_camextrinsics = None, # no board observations
                 observations_point                        = observations_points,
                 indices_point_camintrinsics_camextrinsics = indices_point_camintrinsics_camextrinsics,
                 lensmodel                                 = lensmodel_to,

                 imagersizes                               = nps.atleast_dims(dims, -2),

                 # I'm not optimizing the point positions (frames), so these
                 # need to be set to be inactive, and to include the ranges I do
                 # have
                 point_min_range                           = args.distance * 0.9,
                 point_max_range                           = args.distance * 1.1,

                 # I optimize the lens parameters. That's the whole point
                 do_optimize_intrinsics_core               = True,
                 do_optimize_intrinsics_distortions        = True,

                 do_optimize_extrinsics                    = False,

                 # NOT optimizing the observed point positions
                 do_optimize_frames                        = False )

        if re.match("LENSMODEL_SPLINED_STEREOGRAPHIC_", lensmodel_to):
            # splined models have a core, but those variables are largely redundant
            # with the spline parameters. So I lock down the core when targetting
            # splined models
            optimization_inputs['do_optimize_intrinsics_core'] = False

        # First pass: intrinsics only
        stats = mrcal.optimize(**optimization_inputs,
                               # No outliers. I have the points that I have
                               do_apply_outlier_rejection        = False,
                               verbose                           = False)

        if not args.intrinsics_only:

            # go again, but refine this solution, allowing us to fit the
            # extrinsics too
            optimization_inputs['indices_point_camintrinsics_camextrinsics'][:,2] = \
                np.zeros ((Npoints,), dtype = np.int32)
            optimization_inputs['extrinsics_rt_fromref']  = (np.random.rand(1,6) - 0.5) * 1e-6
            optimization_inputs['do_optimize_extrinsics'] = True

            stats = mrcal.optimize(**optimization_inputs,
                                   # No outliers. I have the points that I have
                                   do_apply_outlier_rejection        = False,
                                   verbose                           = False)

        err_rms = stats['rms_reproj_error__pixels']
        print(f"RMS error of this solution: {err_rms} pixels.",
              file=sys.stderr)
        if err_rms < err_rms_best:
            # Best trial so far. I copy the results: the arrays in
            # optimization_inputs are rebuilt in the next trial
            err_rms_best = err_rms
            intrinsics_data_best  = optimization_inputs['intrinsics'][0,:].copy()
            if not args.intrinsics_only:
                extrinsics_rt_fromref_best = optimization_inputs['extrinsics_rt_fromref'][0,:].copy()

    if intrinsics_data_best is None:
        # No trial ran, or none of them beat the initial err_rms_best
        print("No valid intrinsics found!", file=sys.stderr)
        sys.exit(1)

    sys.stderr.write("RMS error of the BEST solution: {} pixels.\n".format(err_rms_best))

    m_to = mrcal.cameramodel( intrinsics = (lensmodel_to, intrinsics_data_best.ravel()),
                              imagersize = dims )

    if args.intrinsics_only:
        # The extrinsics weren't touched, so no extra transformation is implied
        implied_Rt10 = mrcal.identity_Rt()
    else:
        implied_Rt10 = mrcal.Rt_from_rt(extrinsics_rt_fromref_best)

# The output extrinsics are the input extrinsics, with the transformation
# implied by the solve applied on top
Rt_fromref_input = mrcal.Rt_from_rt( m.extrinsics_rt_fromref() )
m_to.extrinsics_Rt_fromref( mrcal.compose_Rt(implied_Rt10, Rt_fromref_input) )

# I record when and how this model was made, and write the model to stdout
timestamp   = time.strftime("%Y-%m-%d %H:%M:%S")
commandline = ' '.join( map(mrcal.shellquote, sys.argv) )
note        = "generated on {} with   {}\n".format(timestamp, commandline)

m_to.write(sys.stdout, note=note)

if args.viz:
    # Plot the projection difference between the input and output models, at the
    # reference distance, using the same grid density as the solve
    if args.gridn is None:
        gridn_width  = None
        gridn_height = None
    else:
        gridn_width  = args.gridn[0]
        gridn_height = args.gridn[1]

    plot,_ = mrcal.show_projection_diff( (m, m_to),
                                         distance     = args.distance,
                                         cbmax        = 4,
                                         gridn_width  = gridn_width,
                                         gridn_height = gridn_height )
    plot.wait()

