""" Source code for yatsm.utils: job distribution, dataset reading, mapping, result, and miscellaneous utilities """

from __future__ import division

from datetime import datetime as dt
import fnmatch
import os
import re
import sys

import numpy as np
import pandas as pd
import six

try:
    from scandir import walk
except:
    from os import walk

from .log_yatsm import logger


# JOB SPECIFIC FUNCTIONS
def distribute_jobs(job_number, total_jobs, n, interlaced=True):
    """ Assign `job_number` out of `total_jobs` a subset of `n` tasks

    Args:
        job_number (int): 0-indexed processor to distribute jobs to; must
            satisfy ``0 <= job_number < total_jobs``
        total_jobs (int): total number of processors running jobs
        n (int): number of tasks (e.g., lines in image, regions in segment)
        interlaced (bool, optional): interlace job assignment (default:
            True). When True, job ``j`` gets tasks ``j, j + total_jobs,
            j + 2 * total_jobs, ...``. When False, each job gets a contiguous
            chunk of ``n // total_jobs`` tasks and the last job also gets any
            remaining tasks.

    Returns:
        np.ndarray: np.ndarray of task IDs to be processed

    Raises:
        ValueError: raise error if `job_number` is not within
            ``[0, total_jobs)``, or if `job_number` and `total_jobs`
            specified result in no jobs being assigned (e.g., when there are
            fewer tasks than jobs)
    """
    # Reject out-of-range job numbers up front: otherwise a job_number >=
    # total_jobs duplicates another job's tasks, and total_jobs <= 0 would
    # divide by zero or yield nonsense
    if not 0 <= job_number < total_jobs:
        raise ValueError(
            'Invalid job_number/total_jobs: {j}/{t} (job_number is 0-indexed '
            'and must be less than total_jobs)'.format(
                j=job_number, t=total_jobs))

    if interlaced:
        # Every `total_jobs`-th task, starting at this job's offset
        tasks = np.arange(job_number, n, total_jobs)
    else:
        size = int(n / total_jobs)
        i_start = size * job_number
        if job_number == total_jobs - 1:
            # Last job absorbs the remainder so no task goes unassigned
            i_end = n
        else:
            i_end = size * (job_number + 1)
        tasks = np.arange(i_start, min(i_end, n))

    if tasks.size == 0:
        raise ValueError(
            'No jobs assigned for job_number/total_jobs: {j}/{t}'.format(
                j=job_number, t=total_jobs))

    return tasks
def get_output_name(dataset_config, line):
    """ Build the output ``.npz`` filename for one line of a dataset

    The file is named ``<output_prefix><line>.npz`` and lives in the
    directory given by ``dataset_config['output']``.

    Args:
        dataset_config (dict): configuration information about the dataset;
            must contain the ``output`` and ``output_prefix`` keys
        line (int): line of the dataset for output

    Returns:
        filename (str): output filename
    """
    output_dir = dataset_config['output']
    prefix = dataset_config['output_prefix']
    basename = '{0}{1}.npz'.format(prefix, line)
    return os.path.join(output_dir, basename)
# IMAGE DATASET READING
def csvfile_to_dataframe(input_file, date_format='%Y%j'):
    """ Read a dataset CSV file, converting its date column to ordinal days

    The first column whose name contains ``date`` (case-insensitive) is
    parsed with `date_format` and converted to proleptic Gregorian ordinal
    days via :meth:`datetime.datetime.toordinal`. Rows are returned in the
    order they appear in the file; they are NOT sorted.

    Args:
        input_file (str): text file of dates and files
        date_format (str): format of dates in file

    Returns:
        pd.DataFrame: contents of `input_file` (e.g., dates, sensor IDs, and
        filenames) with the date column converted to ordinal integers

    Raises:
        KeyError: if no column name contains ``date``
    """
    df = pd.read_csv(input_file)

    # Guess date field: any column whose name contains "date". str() guards
    # against non-string column labels
    date_cols = [col for col in df.columns if 'date' in str(col).lower()]
    if not date_cols:
        raise KeyError('Could not find date column in input file')
    if len(date_cols) > 1:
        logger.warning('Multiple date columns found in input CSV file. '
                       'Using %s', date_cols[0])
    date_col = date_cols[0]

    # pd.Timestamp subclasses datetime, so datetime.toordinal applies directly
    df[date_col] = pd.to_datetime(
        df[date_col], format=date_format).map(dt.toordinal)

    return df
def get_image_IDs(filenames):
    """ Return an image ID for each filename

    An image ID is the name of the directory that directly contains the file
    (the basename of the file's dirname).

    Args:
        filenames (iterable): filenames to return image IDs for

    Returns:
        list: image IDs for each file in `filenames`, in the same order
    """
    image_ids = []
    for filename in filenames:
        parent_dir = os.path.dirname(filename)
        image_ids.append(os.path.basename(parent_dir))
    return image_ids
# MAPPING UTILITIES
def write_output(raster, output, image_ds, gdal_frmt, ndv, band_names=None):
    """ Write raster to output file

    The output's size, projection, and geotransform are copied from
    `image_ds`. Every band gets `ndv` as its NoDataValue.

    Args:
        raster (np.ndarray): 2D array (single band) or 3D array whose last
            axis indexes bands
        output (str): output filename
        image_ds: GDAL dataset whose RasterXSize/RasterYSize, projection,
            and geotransform are copied to the output
        gdal_frmt (str): GDAL driver name
        ndv (int or float): NoDataValue set on each output band
        band_names (list, optional): one name per band; each is set as the
            band description and as ``band_{i}`` band metadata

    Raises:
        ValueError: if `gdal_frmt` is not a driver known to GDAL
        SystemExit: if `band_names` does not contain exactly one name per
            band (checked before the output file is created)
    """
    from osgeo import gdal, gdal_array

    nband = raster.shape[2] if raster.ndim > 2 else 1

    # Validate before creating the output so a bad call does not leave an
    # empty, partially written file behind
    if band_names is not None and len(band_names) != nband:
        logger.error('Number of band names ({n}) does not match number of '
                     'bands ({b})'.format(n=len(band_names), b=nband))
        sys.exit(1)

    driver = gdal.GetDriverByName(str(gdal_frmt))
    if driver is None:
        raise ValueError('Unknown GDAL format: {f}'.format(f=gdal_frmt))

    logger.debug('Writing output to disk')
    ds = driver.Create(
        output,
        image_ds.RasterXSize, image_ds.RasterYSize, nband,
        gdal_array.NumericTypeCodeToGDALTypeCode(raster.dtype.type)
    )

    for b in range(nband):
        logger.debug('    writing band {b}'.format(b=b + 1))
        band = ds.GetRasterBand(b + 1)
        # A 2D raster is written as-is into the single output band
        band.WriteArray(raster[:, :, b] if raster.ndim > 2 else raster)
        band.SetNoDataValue(ndv)
        if band_names is not None:
            band.SetDescription(band_names[b])
            band.SetMetadata({'band_{i}'.format(i=b + 1): band_names[b]})

    ds.SetProjection(image_ds.GetProjection())
    ds.SetGeoTransform(image_ds.GetGeoTransform())

    # Drop all references so GDAL flushes and closes the output file now
    band = None
    ds = None
# RESULT UTILITIES
def find_results(location, pattern):
    """ Recursively find result files under a directory, sorted by path

    Args:
        location (str): directory location to search (searched recursively)
        pattern (str): glob style search pattern for results, matched
            against file basenames

    Returns:
        results (list): sorted list of file paths for results found

    Raises:
        IOError: if no file under `location` matches `pattern`
    """
    # Note: already checked for location existence in main()
    matches = [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in walk(location)
        for name in fnmatch.filter(names, pattern)
    ]
    if not matches:
        raise IOError('Could not find results in: %s' % location)
    return sorted(matches)
def iter_records(records, warn_on_empty=False, yield_filename=False):
    """ Iterates over records, returning result NumPy array

    Each file is loaded with ``np.load`` and its ``record`` array is yielded.
    Files that cannot be read (I/O errors, truncated or corrupt archives)
    are skipped with a warning. Files whose record has zero rows are also
    skipped.

    Args:
        records (list): List containing filenames of results
        warn_on_empty (bool, optional): Log warning if result contained no
            result records (default: False)
        yield_filename (bool, optional): Yield the filename and the record

    Yields:
        np.ndarray or tuple: Result saved in record and the filename, if
            desired (as ``(record, filename)``)
    """
    import zipfile

    n_records = len(records)
    for i, filename in enumerate(records):
        # Verbose progress
        if i % 100 == 0:
            logger.debug('{0:.1f}%'.format(i / n_records * 100))

        # Open output. The context manager closes the .npz file handle
        # right away instead of leaving it open until garbage collection
        try:
            with np.load(filename) as npz:
                rec = npz['record']
        except (ValueError, AssertionError, IOError, EOFError,
                zipfile.BadZipfile) as e:
            logger.warning('Error reading a result file (may be corrupted) '
                           '({}): {}'.format(filename, str(e)))
            continue

        if rec.shape[0] == 0:
            # No values in this file
            if warn_on_empty:
                logger.warning('Could not find results in {f}'.format(
                    f=filename))
            continue

        if yield_filename:
            yield rec, filename
        else:
            yield rec
# MISC UTILITIES
def date2index(dates, d):
    """ Returns the index at which `d` would be inserted into sorted `dates`

    Uses ``np.searchsorted(..., side='right')``. The result is the index of
    the first element of `dates` strictly greater than `d`, or
    ``len(dates)`` if there is none. Equivalently, ``dates[:index]`` holds
    every element ``<= d``. If `d` is present, the returned index points
    just past its last occurrence, not at it.

    Args:
        dates (np.ndarray): array of dates (or numbers really) in sorted order
        d (int, float): number to search for

    Returns:
        int: insertion index of `d` in `dates` (right side)
    """
    return np.searchsorted(dates, d, side='right')
def is_integer(s):
    """ Returns True if `s` can be converted to an integer with ``int()``

    Note that this is a convertibility check: integer strings such as
    ``'5'`` and finite floats such as ``3.7`` (truncated by ``int()``)
    return True.

    Args:
        s: value to test

    Returns:
        bool: True if ``int(s)`` succeeds, otherwise False
    """
    try:
        int(s)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type (e.g., None)
        # ValueError: non-numeric string or NaN
        # OverflowError: infinite float
        return False
    return True
def copy_dict_filter_key(d, regex):
    """ Copy a dict recursively, but only if key doesn't match regex pattern

    A key is dropped, at every nesting level, when ``re.match(regex, key)``
    succeeds (the pattern matches at the start of the key). Values that are
    dicts are copied recursively; all other values are shared with `d`, not
    copied. Keys that cannot be matched against the pattern's type (e.g.,
    integer keys with a string pattern) are kept.

    Args:
        d (dict): dictionary to copy
        regex (str or compiled pattern): pattern for keys to exclude

    Returns:
        dict: filtered copy of `d`
    """
    # Compile once per level; re.compile returns an already-compiled
    # pattern unchanged, so the recursive calls reuse it
    pattern = re.compile(regex)

    out = {}
    for key, value in d.items():
        try:
            drop = pattern.match(key) is not None
        except TypeError:
            # Non-string keys (e.g., ints) cannot match a string pattern
            drop = False
        if drop:
            continue
        if isinstance(value, dict):
            out[key] = copy_dict_filter_key(value, pattern)
        else:
            out[key] = value
    return out