# Source code for yatsm.cli.pixel

""" Command line interface for running YATSM algorithms on individual pixels
"""
import datetime as dt
import logging
import re

import click
import matplotlib as mpl
import matplotlib.cm  # noqa
import matplotlib.pyplot as plt
import numpy as np
import patsy
import yaml

from . import options, console
from ..algorithms import postprocess
from ..config_parser import convert_config, parse_config_file
from ..io import read_pixel_timeseries
from ..utils import csvfile_to_dataframe, get_image_IDs
from ..regression.transforms import harm  # noqa

avail_plots = ['TS', 'DOY', 'VAL']

# Seasonal symbology used by ``plot_TS(..., seasons=True)``:
#   season name -> (calendar months, matplotlib color, alpha)
SEASONS = {
    'winter': ((11, 12, 1, 2, 3), 'b', 0.5),
    'spring': ((4, 5), 'c', 0.5),
    'summer': ((6, 7, 8), 'g', 1.,),
    'fall': ((9, 10), 'y', 0.5)
}

# Default colormap for DOY/VAL plots: the first of these (in order of
# preference) that the installed matplotlib provides. ``break`` is required,
# otherwise the least preferred available entry would win.
_DEFAULT_PLOT_CMAP = ('viridis', 'cubehelix', 'jet')
_available_cmaps = set(plt.colormaps())
PLOT_CMAP = _DEFAULT_PLOT_CMAP[-1]  # fallback if none are registered
for _cmap in _DEFAULT_PLOT_CMAP:
    if _cmap in _available_cmaps:
        PLOT_CMAP = _cmap
        break

# Styles accepted by ``--style``. Copy matplotlib's list so that appending
# 'xkcd' does not mutate ``mpl.style.available`` itself.
plot_styles = []
if hasattr(mpl, 'style'):
    plot_styles = list(mpl.style.available)
if hasattr(plt, 'xkcd'):
    plot_styles.append('xkcd')

logger = logging.getLogger('yatsm')

# Click command: run the configured YATSM algorithm on the single pixel
# ``(px, py)``, plot band ``band`` of its preprocessed timeseries, refit the
# model records, then re-plot with the model results overlaid.
#
# NOTE(review): this listing appears truncated in several places (each is
# flagged below), so it does not parse as shown. The signature also takes
# ``ctx`` and ``config``, but no visible decorator supplies them -- presumably
# a config-file argument and ``@click.pass_context`` were lost; confirm
# against upstream before relying on this command.
@click.command(short_help='Run YATSM algorithm on individual pixels')
@click.argument('px', metavar='<px>', nargs=1, type=click.INT)
@click.argument('py', metavar='<py>', nargs=1, type=click.INT)
@click.option('--band', metavar='<n>', nargs=1, type=click.INT, default=1,
              show_default=True, help='Band to plot')
@click.option('--plot', default=('TS',), multiple=True, show_default=True,
              type=click.Choice(avail_plots), help='Plot type')
@click.option('--ylim', metavar='<min> <max>', nargs=2, type=float,
              show_default=True, help='Y-axis limits')
@click.option('--style', metavar='<style>', default='ggplot',
              show_default=True, type=click.Choice(plot_styles),
              help='Plot style')
@click.option('--cmap', metavar='<cmap>', default=PLOT_CMAP,
              show_default=True, help='DOY/VAL plot colormap')
@click.option('--embed', is_flag=True,
              help='Drop to (I)Python interpreter at various points')
# NOTE(review): no ``type`` is given, so click passes ``seed`` as a str.
@click.option('--seed', help='Set NumPy RNG seed value')
@click.option('--algo_kw', multiple=True, callback=options.callback_dict,
              help='Algorithm parameter overrides')
# NOTE(review): without ``multiple=True`` this yields a single str, yet the
# body iterates ``result_prefix`` as a collection of prefixes (it would
# iterate characters) -- confirm the intended declaration.
@click.option('--result_prefix', type=str, default='', show_default=True,
              help='Plot coef/rmse from refit that used this prefix')
@click.option('--seasons', is_flag=True, help='Plot using seasonal symbology')
def pixel(ctx, config, px, py, band, plot, ylim, style, cmap,
          embed, seed, algo_kw, result_prefix, seasons):
    # Set seed
    # NOTE(review): no seeding statement follows; ``seed`` is unused in the
    # visible code.
    # Convert band to index
    band -= 1
    # Format result prefix: make every prefix end in '_' and always include
    # '' so the original (non-refit) fit is plotted as well
    if result_prefix:
        result_prefix = set((_pref if _pref[-1] == '_' else _pref + '_')
                            for _pref in result_prefix)
        result_prefix.add('')  # add in no prefix to show original fit
        # NOTE(review): as written, the next line unconditionally discards the
        # set built above; an ``else:`` (no prefix given) looks to be missing.
        result_prefix = ('', )

    # Get colormap
    # NOTE(review): the membership test and the rest of this error message
    # are truncated.
    if cmap not in
        raise click.ClickException('Cannot find specified colormap ({}) in '

    # Parse config
    cfg = parse_config_file(config)

    # Apply algorithm overrides: each ``--algo_kw`` value is parsed as YAML
    # and substituted for matching keys throughout the configuration via
    # ``trawl_replace_keys``
    for kw in algo_kw:
        # NOTE(review): ``yaml.load`` without a ``Loader`` is unsafe on
        # untrusted text and raises TypeError on PyYAML >= 6; this value comes
        # from the command line, so ``yaml.safe_load`` is likely intended.
        value = yaml.load(algo_kw[kw])
        cfg = trawl_replace_keys(cfg, kw, value)
    if algo_kw:  # revalidate configuration
        cfg = convert_config(cfg)

    # Dataset information
    # NOTE(review): the remaining arguments and closing parenthesis of the
    # next call are missing.
    df = csvfile_to_dataframe(cfg['dataset']['input_file'],
    df['image_ID'] = get_image_IDs(df['filename'])
    df['x'] = df['date']
    dates = df['date'].values

    # Initialize timeseries model
    model = cfg['YATSM']['algorithm_object']
    algo_cfg = cfg[cfg['YATSM']['algorithm']]
    yatsm = model(estimator=cfg['YATSM']['estimator'],
                  **algo_cfg.get('init', {}))
    # NOTE(review): chained assignment -- this sets BOTH ``yatsm.px`` and the
    # local ``px`` to ``py``, so the pixel read and plot titles below use
    # ``py`` in place of ``px``. Likely intended ``yatsm.px = px`` followed by
    # ``yatsm.py = py``.
    yatsm.px = px = py

    # Setup algorithm and create design matrix (if needed)
    X = yatsm.setup(df, **cfg)
    # patsy design info (None when X carries none); passed to plot_results
    design_info = getattr(X, 'design_info', None)

    # Read pixel data; the first axis of Y must equal the configured band
    # count
    Y = read_pixel_timeseries(df['filename'], px, py)
    if Y.shape[0] != cfg['dataset']['n_bands']:
        raise click.ClickException(
            'Number of bands in image {f} ({nf}) do not match number in '
            'configuration file ({nc})'.format(
            # NOTE(review): the ``f``/``nf``/``nc`` arguments and closing
            # parentheses are missing.

    # Preprocess pixel data
    X, Y, dates = yatsm.preprocess(X, Y, dates, **cfg['dataset'])

    # Convert ordinal to datetime
    dt_dates = np.array([dt.datetime.fromordinal(d) for d in dates])

    # Plot before fitting: draw each requested plot type for band ``band``
    # NOTE(review): the style context for ``style != 'xkcd'`` is missing after
    # ``else``; nothing visible shows or closes the figure between plot types.
    with plt.xkcd() if style == 'xkcd' else
        for _plot in plot:
            if _plot == 'TS':
                plot_TS(dt_dates, Y[band, :], seasons)
            elif _plot == 'DOY':
                plot_DOY(dt_dates, Y[band, :], cmap)
            elif _plot == 'VAL':
                plot_VAL(dt_dates, Y[band, :], cmap)

            if ylim:
                # NOTE(review): body missing (presumably applies ``ylim``).
            plt.title('Timeseries: px={px} py={py}'.format(px=px, py=py))
            plt.ylabel('Band {b}'.format(b=band + 1))

    # Fit model, Y, dates, **algo_cfg.get('fit', {}))
    # NOTE(review): the fit call appears to have been merged into the comment
    # above, so the visible code never fits the model before refitting.
    # Apply each configured refit, replacing ``yatsm.record`` each time
    for prefix, estimator, stay_reg, fitopt in zip(
            # NOTE(review): the sequences being zipped are missing.
        yatsm.record = postprocess.refit_record(
            yatsm, prefix, estimator,
            fitopt=fitopt, keep_regularized=stay_reg)

    # Plot after predictions
    # NOTE(review): same missing style context as above; the loop below is
    # also indented one level deeper than in the first plotting block.
    with plt.xkcd() if style == 'xkcd' else
            for _plot in plot:
                if _plot == 'TS':
                    plot_TS(dt_dates, Y[band, :], seasons)
                elif _plot == 'DOY':
                    plot_DOY(dt_dates, Y[band, :], cmap)
                elif _plot == 'VAL':
                    plot_VAL(dt_dates, Y[band, :], cmap)

                if ylim:
                    # NOTE(review): body missing (presumably applies ``ylim``).
                plt.title('Timeseries: px={px} py={py}'.format(px=px, py=py))
                plt.ylabel('Band {b}'.format(b=band + 1))

                # Overlay model results for every requested prefix
                # ('' = original fit)
                for _prefix in set(result_prefix):
                    plot_results(band, cfg, yatsm, design_info,
                    # NOTE(review): remaining arguments missing (presumably
                    # the prefix and the plot type).


    # Optionally drop into an interactive interpreter (``--embed``)
    if embed:
        # NOTE(review): the call these keyword arguments belong to is
        # missing; the visible fragments pass a help message, the parsed
        # config as 'config', and the four plot helpers.
            message=("Additional functions:\n"
                     "plot_TS, plot_DOY, plot_VAL, plot_results"),
                'config': cfg,
                'plot_TS': plot_TS, 'plot_DOY': plot_DOY,
                'plot_VAL': plot_VAL, 'plot_results': plot_results

def plot_TS(dates, y, seasons):
    """ Create a standard timeseries plot

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        seasons (bool): Plot seasonal symbology; each season in ``SEASONS``
            is drawn with its own color and alpha

    """
    # Boolean-mask indexing below needs arrays, so accept any sequence
    dates = np.asarray(dates)
    y = np.asarray(y)

    # Plot data
    if seasons:
        months = np.array([d.month for d in dates])
        for season_months, color, alpha in SEASONS.values():
            # ``np.isin`` replaces the deprecated ``np.in1d``
            season_idx = np.isin(months, season_months)
            plt.plot(dates[season_idx], y[season_idx], marker='o',
                     mec=color, mfc=color, alpha=alpha, ls='')
    else:
        plt.scatter(dates, y, c='k', marker='o', edgecolors='none', s=35)
    plt.xlabel('Date')
def plot_DOY(dates, y, mpl_cmap):
    """ Create a DOY plot

    Observations are placed by day of year and colored by year, with
    abbreviated month names as minor tick labels along the x axis.

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        mpl_cmap (colormap): matplotlib colormap

    """
    doy = np.array([d.timetuple().tm_yday for d in dates])
    year = np.array([d.year for d in dates])

    sp = plt.scatter(doy, y, c=year, cmap=mpl_cmap,
                     marker='o', edgecolors='none', s=35)
    plt.colorbar(sp)

    # Decorate the Axes that actually holds the scatter. ``plt.axes()`` with
    # no arguments adds a *new* Axes in current matplotlib, which would leave
    # the month ticks and x-limits on an empty overlay instead.
    ax = sp.axes

    months = mpl.dates.MonthLocator()  # every month
    months_fmt = mpl.dates.DateFormatter('%b')
    ax.tick_params(axis='x', which='minor', direction='in', pad=-10)
    ax.xaxis.set_minor_locator(months)
    ax.xaxis.set_minor_formatter(months_fmt)

    ax.set_xlim(1, 366)
    ax.set_xlabel('Day of Year')
def plot_VAL(dates, y, mpl_cmap, reps=2):
    """ Create a "Valerie Pasquarella" plot (repeated DOY plot)

    The day-of-year scatter is drawn ``reps + 1`` times side by side; copy
    number ``k`` (counting from 0) is shifted right by ``k * 366`` days.
    Points are colored by year.

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        mpl_cmap (colormap): matplotlib colormap
        reps (int, optional): number of additional repetitions

    """
    day_of_year = np.array([d.timetuple().tm_yday for d in dates])
    years = np.array([d.year for d in dates])

    n_copies = reps + 1
    shifted_doy = np.concatenate(
        [day_of_year + offset * 366 for offset in range(n_copies)])
    repeated_years = np.tile(years, n_copies)
    repeated_y = np.tile(y, n_copies)

    scatter = plt.scatter(shifted_doy, repeated_y, c=repeated_years,
                          cmap=mpl_cmap, marker='o', edgecolors='none', s=35)
    plt.colorbar(scatter)
    plt.xlabel('Day of Year')
def plot_results(band, cfg, model, design_info,
                 result_prefix='', plot_type='TS'):
    """ Plot model results

    Each segment in ``model.record`` is predicted from its coefficients, with
    categorical ``C(...)`` terms removed from the design, and drawn on the
    current axes.

    - ``'TS'`` plots predict over the segment's own time span and mark
      recorded breaks with a red vertical line.
    - ``'DOY'``/``'VAL'`` plots predict over the calendar year midway between
      the segment's start and end years.

    Args:
        band (int): plot results for this band
        cfg (dict): YATSM configuration dictionary
        model (YATSM model): fitted YATSM timeseries model
        design_info (patsy.DesignInfo): patsy design information
        result_prefix (str): Prefix to 'coef' and 'rmse'
        plot_type (str): type of plot to add results to (TS, DOY, or VAL)

    Raises:
        KeyError: if ``result_prefix + 'coef'`` or ``result_prefix + 'rmse'``
            is not a field of ``model.record``
        ValueError: if ``plot_type`` is not TS, DOY, or VAL

    """
    # Results prefix
    result_k = model.record.dtype.names
    coef_k = result_prefix + 'coef'
    rmse_k = result_prefix + 'rmse'
    if coef_k not in result_k or rmse_k not in result_k:
        raise KeyError('Cannot find result prefix "{}" in results'
                       .format(result_prefix))
    if plot_type not in ('TS', 'DOY', 'VAL'):
        # Otherwise an unknown type only fails later with a NameError
        raise ValueError('Unknown plot type "{}" (expected TS, DOY, or VAL)'
                         .format(plot_type))
    if result_prefix:
        click.echo('Using "{}" re-fitted results'.format(result_prefix))

    # Handle reverse
    step = -1 if cfg['YATSM']['reverse'] else 1

    # Remove categorical info from predictions
    design = re.sub(r'[\+\-][\ ]+C\(.*\)', '',
                    cfg['YATSM']['design_matrix'])

    # Coefficient rows for the non-categorical design columns, in order
    i_coef = np.sort(np.asarray(
        [v for k, v in design_info.column_name_indexes.items()
         if not re.match(r'C\(.*\)', k)],
        dtype=int))

    _prefix = result_prefix or cfg['YATSM']['prediction']
    for i, r in enumerate(model.record):
        coef = r[coef_k][i_coef, band]
        if plot_type == 'TS':
            label = 'Model {i} ({prefix})'.format(i=i, prefix=_prefix)
            # Prediction over the segment's own time span
            mx = np.arange(r['start'], r['end'], step)
            mX = patsy.dmatrix(design, {'x': mx}).T
            my = np.dot(coef, mX)
            mx_date = np.array([dt.datetime.fromordinal(int(_x))
                                for _x in mx])
            # Break
            if r['break']:
                bx = dt.datetime.fromordinal(int(r['break']))
                plt.axvline(bx, c='red', lw=2)
        else:  # 'DOY' or 'VAL'
            yr_end = dt.datetime.fromordinal(int(r['end'])).year
            yr_start = dt.datetime.fromordinal(int(r['start'])).year
            yr_mid = int(yr_end - (yr_end - yr_start) / 2)
            label = 'Model {i} - {yr} ({prefix})'.format(
                i=i, yr=yr_mid, prefix=_prefix)

            # Prediction over every day of the middle year
            mx = np.arange(dt.date(yr_mid, 1, 1).toordinal(),
                           dt.date(yr_mid + 1, 1, 1).toordinal(), 1)
            mX = patsy.dmatrix(design, {'x': mx}).T
            my = np.dot(coef, mX)
            mx_date = np.array(
                [dt.datetime.fromordinal(int(d)).timetuple().tm_yday
                 for d in mx])

        plt.plot(mx_date, my, lw=3, label=label)

    leg = plt.legend()
    # ``Legend.draggable`` was replaced by ``Legend.set_draggable`` and later
    # removed from matplotlib; support both APIs
    if hasattr(leg, 'set_draggable'):
        leg.set_draggable(True)
    else:
        leg.draggable(state=True)
def trawl_replace_keys(d, key, value, s=''):
    """ Return modified dictionary ``d``

    Recursively walk ``d`` and replace the value of every non-dict entry
    whose key equals ``key`` with ``value``, echoing each replacement. Every
    level is shallow-copied, so ``d`` and its nested dicts are not modified.

    Args:
        d (dict): (possibly nested) dictionary to search
        key: key whose (non-dict) values should be replaced
        value: replacement value
        s (str): bracketed path to ``d`` used in the echoed message

    Returns:
        dict: copy of ``d`` with the replacements applied

    """
    md = d.copy()
    for _key in md:
        # Per-key path; ``s`` itself must not be modified here, otherwise a
        # match would corrupt the reported path of every later sibling
        path = '{}[{}]'.format(s, _key)
        if isinstance(md[_key], dict):
            # Recursively replace
            md[_key] = trawl_replace_keys(md[_key], key, value, s=path)
        elif _key == key:
            click.echo('Replacing d{k}={ov} with {nv}'
                       .format(k=path, ov=md[_key], nv=value))
            md[_key] = value
    return md