# Source code for yatsm.cli.pixel

""" Command line interface for running YATSM algorithms on individual pixels
"""
import datetime as dt
import logging
import re

import click
import matplotlib as mpl
import matplotlib.cm  # noqa
import matplotlib.pyplot as plt
import numpy as np
import patsy
import yaml

from . import options, console
from ..algorithms import postprocess
from ..config_parser import convert_config, parse_config_file
from ..io import read_pixel_timeseries
from ..utils import csvfile_to_dataframe, get_image_IDs
from ..regression.transforms import harm  # noqa

# Plot types accepted by the ``--plot`` option
avail_plots = ['TS', 'DOY', 'VAL']

# Seasonal symbology used by ``plot_TS``:
#   season name -> (months belonging to the season, marker color, alpha)
SEASONS = {
    'winter': ((11, 12, 1, 2, 3), 'b', 0.5),
    'spring': ((4, 5), 'c', 0.5),
    'summer': ((6, 7, 8), 'g', 1.0),
    'fall': ((9, 10), 'y', 0.5),
}

# Candidate default colormaps in order of preference.  The last entry is the
# fallback used when none of the candidates is registered with matplotlib.
_DEFAULT_PLOT_CMAP = ('viridis', 'cubehelix', 'jet')
# ``mpl.cm.cmap_d`` was removed in matplotlib 3.6; ``plt.colormaps()`` returns
# the registered colormap names on both old and new matplotlib releases.
_registered_cmaps = set(plt.colormaps())
PLOT_CMAP = _DEFAULT_PLOT_CMAP[-1]
for _cmap in _DEFAULT_PLOT_CMAP:
    if _cmap in _registered_cmaps:
        PLOT_CMAP = _cmap
        break

# Styles offered by the ``--style`` option.  matplotlib's list is copied so
# that appending 'xkcd' below does not modify ``mpl.style.available`` itself.
plot_styles = []
if hasattr(mpl, 'style'):
    plot_styles = list(mpl.style.available)
if hasattr(plt, 'xkcd'):
    plot_styles.append('xkcd')

logger = logging.getLogger('yatsm')


def _format_result_prefixes(result_prefix):
    """ Return the result prefixes to plot, each ending with an underscore

    Args:
        result_prefix (iterable of str): prefixes given via
            ``--result_prefix``; may be empty

    Returns:
        set or tuple: ``('', )`` when no prefix was given, otherwise a set of
            the prefixes (with ``_`` appended where missing) plus ``''`` so
            the original fit is shown as well

    """
    if not result_prefix:
        return ('', )
    # ``endswith`` instead of ``_pref[-1]`` so an empty prefix cannot raise
    # IndexError; an empty prefix is left as-is (it means the original fit)
    prefixes = set(
        _pref if (not _pref or _pref.endswith('_')) else _pref + '_'
        for _pref in result_prefix)
    prefixes.add('')  # add in no prefix to show original fit
    return prefixes


def _show_plots(plot, dt_dates, y, px, py, band, ylim, style, cmap, seasons,
                overlay=None):
    """ Draw and show one figure for each requested plot type

    Args:
        plot (iterable of str): plot types to draw ('TS', 'DOY', 'VAL')
        dt_dates (np.ndarray): datetime of each observation
        y (np.ndarray): observations of the band being plotted
        px (int): pixel column, used in the title
        py (int): pixel row, used in the title
        band (int): 0-based band index, shown 1-based in the y label
        ylim (tuple): y-axis limits; ignored if empty or None
        style (str): matplotlib style name, or 'xkcd'
        cmap (str): colormap name for the DOY/VAL plots
        seasons (bool): use seasonal symbology in the TS plot
        overlay (callable, optional): called with the plot type after the
            data are drawn, to add more artists to the current figure

    """
    with plt.xkcd() if style == 'xkcd' else mpl.style.context(style):
        for _plot in plot:
            if _plot == 'TS':
                plot_TS(dt_dates, y, seasons)
            elif _plot == 'DOY':
                plot_DOY(dt_dates, y, cmap)
            elif _plot == 'VAL':
                plot_VAL(dt_dates, y, cmap)

            if ylim:
                plt.ylim(ylim)
            plt.title('Timeseries: px={px} py={py}'.format(px=px, py=py))
            plt.ylabel('Band {b}'.format(b=band + 1))

            if overlay is not None:
                overlay(_plot)

            plt.tight_layout()
            plt.show()


@click.command(short_help='Run YATSM algorithm on individual pixels')
@options.arg_config_file
@click.argument('px', metavar='<px>', nargs=1, type=click.INT)
@click.argument('py', metavar='<py>', nargs=1, type=click.INT)
@click.option('--band', metavar='<n>', nargs=1, type=click.INT, default=1,
              show_default=True, help='Band to plot')
@click.option('--plot', default=('TS',), multiple=True, show_default=True,
              type=click.Choice(avail_plots), help='Plot type')
@click.option('--ylim', metavar='<min> <max>', nargs=2, type=float,
              show_default=True, help='Y-axis limits')
@click.option('--style', metavar='<style>', default='ggplot',
              show_default=True, type=click.Choice(plot_styles),
              help='Plot style')
@click.option('--cmap', metavar='<cmap>', default=PLOT_CMAP,
              show_default=True, help='DOY/VAL plot colormap')
@click.option('--embed', is_flag=True,
              help='Drop to (I)Python interpreter at various points')
@click.option('--seed', type=click.INT, help='Set NumPy RNG seed value')
@click.option('--algo_kw', multiple=True, callback=options.callback_dict,
              help='Algorithm parameter overrides')
@click.option('--result_prefix', type=str, default='', show_default=True,
              multiple=True,
              help='Plot coef/rmse from refit that used this prefix')
@click.option('--seasons', is_flag=True, help='Plot using seasonal symbology')
@click.pass_context
def pixel(ctx, config, px, py, band, plot, ylim, style, cmap,
          embed, seed, algo_kw, result_prefix, seasons):
    # Command body: read one pixel's timeseries, plot it, fit the configured
    # algorithm, run any configured refits, then plot again with the model
    # results overlaid.  (Kept as a comment rather than a docstring so the
    # command's ``--help`` text is unchanged.)

    # Set seed (``--seed`` is parsed as an integer; None leaves NumPy's
    # default seeding in place)
    np.random.seed(seed)
    # Convert band to index
    band -= 1
    # Format result prefix
    result_prefix = _format_result_prefixes(result_prefix)

    # Get colormap (``plt.colormaps()`` works on matplotlib versions where
    # ``mpl.cm.cmap_d`` no longer exists)
    if cmap not in plt.colormaps():
        raise click.ClickException('Cannot find specified colormap ({}) in '
                                   'matplotlib'.format(cmap))

    # Parse config
    cfg = parse_config_file(config)

    # Apply algorithm overrides
    for kw in algo_kw:
        # NOTE(review): values come from the command line, so they are parsed
        # with ``safe_load``; plain ``yaml.load`` can construct arbitrary
        # objects and requires an explicit Loader on PyYAML >= 6
        value = yaml.safe_load(algo_kw[kw])
        cfg = trawl_replace_keys(cfg, kw, value)
    if algo_kw:  # revalidate configuration
        cfg = convert_config(cfg)

    # Dataset information
    df = csvfile_to_dataframe(cfg['dataset']['input_file'],
                              date_format=cfg['dataset']['date_format'])
    df['image_ID'] = get_image_IDs(df['filename'])
    df['x'] = df['date']
    dates = df['date'].values

    # Initialize timeseries model
    model = cfg['YATSM']['algorithm_object']
    algo_cfg = cfg[cfg['YATSM']['algorithm']]
    yatsm = model(estimator=cfg['YATSM']['estimator'],
                  **algo_cfg.get('init', {}))
    yatsm.px = px
    yatsm.py = py

    # Setup algorithm and create design matrix (if needed)
    X = yatsm.setup(df, **cfg)
    design_info = getattr(X, 'design_info', None)

    # Read pixel data
    Y = read_pixel_timeseries(df['filename'], px, py)
    if Y.shape[0] != cfg['dataset']['n_bands']:
        raise click.ClickException(
            'Number of bands in image {f} ({nf}) do not match number in '
            'configuration file ({nc})'.format(
                f=df['filename'][0],
                nf=Y.shape[0],
                nc=cfg['dataset']['n_bands']))

    # Preprocess pixel data
    X, Y, dates = yatsm.preprocess(X, Y, dates, **cfg['dataset'])

    # Convert ordinal to datetime
    dt_dates = np.array([dt.datetime.fromordinal(d) for d in dates])

    # Plot before fitting
    _show_plots(plot, dt_dates, Y[band, :], px, py, band,
                ylim, style, cmap, seasons)

    # Fit model
    yatsm.fit(X, Y, dates, **algo_cfg.get('fit', {}))
    for prefix, estimator, stay_reg, fitopt in zip(
            cfg['YATSM']['refit']['prefix'],
            cfg['YATSM']['refit']['prediction_object'],
            cfg['YATSM']['refit']['stay_regularized'],
            cfg['YATSM']['refit']['fit']):
        yatsm.record = postprocess.refit_record(
            yatsm, prefix, estimator,
            fitopt=fitopt, keep_regularized=stay_reg)

    # Plot after predictions, overlaying results for each requested prefix
    def _overlay_results(plot_type):
        for _prefix in set(result_prefix):
            plot_results(band, cfg, yatsm, design_info,
                         result_prefix=_prefix,
                         plot_type=plot_type)

    _show_plots(plot, dt_dates, Y[band, :], px, py, band,
                ylim, style, cmap, seasons, overlay=_overlay_results)

    if embed:
        console.open_interpreter(
            yatsm,
            message=("Additional functions:\n"
                     "plot_TS, plot_DOY, plot_VAL, plot_results"),
            variables={
                'config': cfg,
            },
            funcs={
                'plot_TS': plot_TS, 'plot_DOY': plot_DOY,
                'plot_VAL': plot_VAL, 'plot_results': plot_results
            }
        )


def plot_TS(dates, y, seasons):
    """ Create a standard timeseries plot

    Draws on the current matplotlib axes and labels the x-axis 'Date'.

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        seasons (bool): Plot seasonal symbology

    """
    # Plot data
    if seasons:
        months = np.array([d.month for d in dates])
        # The boolean-mask indexing below needs arrays; coerce so that any
        # sequence (e.g. a list) is accepted, as the docstring promises.
        # ``dtype=object`` keeps the datetime objects themselves.
        dates = np.asarray(dates, dtype=object)
        y = np.asarray(y)
        # One plot call per season so each gets its own color and alpha
        for season_months, color, alpha in SEASONS.values():
            season_idx = np.in1d(months, season_months)
            plt.plot(dates[season_idx], y[season_idx], marker='o',
                     mec=color, mfc=color, alpha=alpha, ls='')
    else:
        plt.scatter(dates, y, c='k', marker='o', edgecolors='none', s=35)
    plt.xlabel('Date')
def plot_DOY(dates, y, mpl_cmap):
    """ Create a DOY plot

    Scatters ``y`` against day of year, colored by year, with a colorbar and
    month-name minor tick labels along the x-axis.

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        mpl_cmap (colormap): matplotlib colormap

    """
    doy = np.array([d.timetuple().tm_yday for d in dates])
    year = np.array([d.year for d in dates])

    sp = plt.scatter(doy, y, c=year, cmap=mpl_cmap,
                     marker='o', edgecolors='none', s=35)
    plt.colorbar(sp)

    months = mpl.dates.MonthLocator()  # every month
    months_fmrt = mpl.dates.DateFormatter('%b')

    plt.tick_params(axis='x', which='minor', direction='in', pad=-10)
    # ``plt.gca()`` returns the axes being drawn on; a bare ``plt.axes()``
    # adds a brand new Axes to the figure on current matplotlib releases
    ax = plt.gca()
    ax.xaxis.set_minor_locator(months)
    ax.xaxis.set_minor_formatter(months_fmrt)

    plt.xlim(1, 366)
    plt.xlabel('Day of Year')
def plot_VAL(dates, y, mpl_cmap, reps=2):
    """ Create a "Valerie Pasquarella" plot (repeated DOY plot)

    The day-of-year scatter is drawn once and then repeated ``reps`` more
    times, each repetition shifted a further 366 days along the x-axis.

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        mpl_cmap (colormap): matplotlib colormap
        reps (int, optional): number of additional repetitions

    """
    day_of_year = np.array([d.timetuple().tm_yday for d in dates])
    years = np.array([d.year for d in dates])

    # Original series followed by `reps` shifted copies
    n_copies = reps + 1
    doy_pieces = [day_of_year.copy()]
    doy_pieces.extend(day_of_year + shift * 366
                      for shift in range(1, n_copies))
    rep_doy = np.concatenate(doy_pieces)
    rep_year = np.tile(years, n_copies)
    rep_y = np.tile(y, n_copies)

    points = plt.scatter(rep_doy, rep_y, c=rep_year, cmap=mpl_cmap,
                         marker='o', edgecolors='none', s=35)
    plt.colorbar(points)
    plt.xlabel('Day of Year')
def plot_results(band, cfg, model, design_info,
                 result_prefix='', plot_type='TS'):
    """ Plot model results

    Adds one prediction line per record in ``model.record`` to the current
    matplotlib axes, then adds a draggable legend.

    Args:
        band (int): plot results for this band
        cfg (dict): YATSM configuration dictionary
        model (YATSM model): fitted YATSM timeseries model
        design_info (patsy.DesignInfo): patsy design information
        result_prefix (str): Prefix to 'coef' and 'rmse'
        plot_type (str): type of plot to add results to (TS, DOY, or VAL)

    Raises:
        KeyError: if the prefixed 'coef' or 'rmse' field is not in the
            model's record
        ValueError: if ``plot_type`` is not one of TS, DOY, or VAL

    """
    # Results prefix
    result_k = model.record.dtype.names
    coef_k = result_prefix + 'coef'
    rmse_k = result_prefix + 'rmse'
    if coef_k not in result_k or rmse_k not in result_k:
        raise KeyError('Cannot find result prefix "{}" in results'
                       .format(result_prefix))
    if result_prefix:
        click.echo('Using "{}" re-fitted results'.format(result_prefix))

    # Handle reverse
    step = -1 if cfg['YATSM']['reverse'] else 1

    # Remove categorical info from predictions
    design = re.sub(r'[\+\-][\ ]+C\(.*\)', '', cfg['YATSM']['design_matrix'])

    # Keep only the coefficients of non-categorical design matrix columns
    # (raw string: '\(' is an invalid escape in a normal string literal)
    categorical = re.compile(r'C\(.*\)')
    i_coef = [v for k, v in design_info.column_name_indexes.items()
              if not categorical.match(k)]
    i_coef = np.sort(np.asarray(i_coef))

    _prefix = result_prefix or cfg['YATSM']['prediction']
    for i, r in enumerate(model.record):
        label = 'Model {i} ({prefix})'.format(i=i, prefix=_prefix)
        if plot_type == 'TS':
            # Prediction
            mx = np.arange(r['start'], r['end'], step)
            mX = patsy.dmatrix(design, {'x': mx}).T

            my = np.dot(r[coef_k][i_coef, band], mX)
            mx_date = np.array([dt.datetime.fromordinal(int(_x))
                                for _x in mx])
            # Break
            if r['break']:
                bx = dt.datetime.fromordinal(int(r['break']))
                plt.axvline(bx, c='red', lw=2)

        elif plot_type in ('DOY', 'VAL'):
            # Predict over one calendar year: the middle year of the segment
            yr_end = dt.datetime.fromordinal(int(r['end'])).year
            yr_start = dt.datetime.fromordinal(int(r['start'])).year
            yr_mid = int(yr_end - (yr_end - yr_start) / 2)

            mx = np.arange(dt.date(yr_mid, 1, 1).toordinal(),
                           dt.date(yr_mid + 1, 1, 1).toordinal(), 1)
            mX = patsy.dmatrix(design, {'x': mx}).T

            my = np.dot(r[coef_k][i_coef, band], mX)
            mx_date = np.array(
                [dt.datetime.fromordinal(int(d)).timetuple().tm_yday
                 for d in mx])

            label = 'Model {i} - {yr} ({prefix})'.format(i=i, yr=yr_mid,
                                                         prefix=_prefix)
        else:
            # Previously this surfaced as a NameError on `mx_date` below
            raise ValueError('Unknown plot type "{}" (expected TS, DOY, or '
                             'VAL)'.format(plot_type))

        plt.plot(mx_date, my, lw=3, label=label)

    leg = plt.legend()
    # `Legend.draggable` was removed in matplotlib 3.1 in favor of
    # `set_draggable`; fall back to the old name on older releases
    if hasattr(leg, 'set_draggable'):
        leg.set_draggable(True)
    else:
        leg.draggable(state=True)
# UTILITY FUNCTIONS
def trawl_replace_keys(d, key, value, s=''):
    """ Return modified dictionary ``d``

    Every entry named ``key`` whose value is not itself a dictionary is set
    to ``value``, at any nesting depth.  Dictionary values are recursed into
    rather than replaced.  ``d`` is not modified: each dictionary visited is
    copied before being changed.  Each replacement is reported with
    ``click.echo``.

    Args:
        d (dict): dictionary to search
        key (str): name of the entries to replace
        value (object): replacement value
        s (str, optional): key path of ``d`` within the outermost dictionary,
            used only in the replacement message

    Returns:
        dict: copy of ``d`` with the replacements applied

    """
    md = d.copy()
    for _key in md:
        # Build this entry's path fresh each iteration; appending to ``s``
        # in place would leak one match's path into later siblings' messages
        path = '{}[{}]'.format(s, _key)
        if isinstance(md[_key], dict):
            # Recursively replace
            md[_key] = trawl_replace_keys(md[_key], key, value, s=path)
        elif _key == key:
            click.echo('Replacing d{k}={ov} with {nv}'
                       .format(k=path, ov=md[_key], nv=value))
            md[_key] = value
    return md