"""CCAM DRS post-processing"""
import os
import re
import argparse
from pathlib import Path
import configparser as cp
import xarray as xr
import glob
import sys
import axiom.utilities as au
import axiom.schemas as axs
import numpy as np
import importlib
from pprint import pprint
from calendar import monthrange
import pandas as pd
from datetime import datetime, timedelta
from uuid import uuid4
import re
from axiom.exceptions import ResolutionDetectionException, MalformedDRSJSONPayloadException
from cerberus import Validator
from axiom.drs.domain import Domain
from axiom.config import load_config
import shutil
def is_fixed_variable(config, variable):
    """Check whether a variable appears in the config's list of fixed variables.

    Args:
        config (dict): Configuration dictionary containing a 'fixed_variables' entry.
        variable (str): Variable name.

    Returns:
        bool: True if the variable is listed as fixed, False otherwise.
    """
    fixed = config['fixed_variables']
    return variable in fixed
def get_template(config, key):
    """Fetch an interpolation template from the config.

    A template may be stored either as a plain string or as a list, where the
    first element is the delimiter used to join the remaining pieces (handy
    for very long templates).

    Args:
        config (dict): Dictionary with a 'templates' mapping.
        key (str): Template key.

    Returns:
        str: Template.
    """
    raw = config['templates'][key]
    if not isinstance(raw, list):
        return raw
    separator = raw[0]
    pieces = raw[1:]
    return separator.join(pieces)
def get_domains(resolution, frequency, variable_fixed, no_frequencies):
    """Get the domains for the arguments provided.

    Args:
        resolution (int): Resolution in km; must be one of 50, 5, 10 or 12.
        frequency (str): Frequency. The special value 'cordex' selects the
            CORDEX domain where one exists.
        variable_fixed (bool): Is the variable fixed?
        no_frequencies (bool): True if no frequencies were provided.

    Raises:
        AssertionError: When the resolution is not one of the supported values.
        Exception: When resolution is 12 and the frequency is not 'cordex'.

    Returns:
        dict: Dictionary with keys ``domains``, ``frequencies``,
        ``resolution_dir`` and ``degrees``. Entries that are not set for the
        given resolution are None.
    """
    is_cordex = frequency == 'cordex'
    domains, frequencies, resolution_dir, degrees = None, None, None, None

    # Quickly bail if necessary. Raised explicitly (rather than via `assert`)
    # so the check still runs under `python -O`; AssertionError is kept so
    # existing callers' exception handling continues to work.
    if resolution not in (50, 5, 10, 12):
        raise AssertionError(f'Unsupported resolution: {resolution}')

    if resolution == 50:
        degrees = 0.5
        resolution_dir = f'{resolution}km'
        domains = ['AUS-44i'] if is_cordex else ['AUS-50']
        frequencies = ['1D', '1M']
        if variable_fixed:
            frequencies = ['1D']
        # Checked last, so it overrides the fixed-variable restriction above.
        if no_frequencies:
            frequencies = ['3H', '1D', '1M']
    elif resolution == 5:
        domains = ['VIC-5']
        degrees = 0.05
        resolution_dir = f'{resolution}km'
    elif resolution == 10:
        domains = ['SEA-10', 'TAS-10']
        # Must be a plain string (a stray trailing comma here would make it a
        # one-element tuple).
        # NOTE(review): 10 km data maps to the '5km' directory; confirm intended.
        resolution_dir = '5km'
        degrees = 0.1
    elif resolution == 12:
        if not is_cordex:
            raise Exception('Domain not known.')
        domains = ['AUS-44i']
        frequencies = ['1D', '1M']
        if variable_fixed:
            frequencies = ['1D']

    # Return everything
    return dict(
        domains=domains,
        frequencies=frequencies,
        resolution_dir=resolution_dir,
        degrees=degrees
    )
def standardise_units(ds):
    """Standardise units.

    Only converts mm to m for now. The dataset is modified in place and also
    returned.

    Args:
        ds (xarray.Dataset): Dataset

    Returns:
        xarray.Dataset: Dataset with units standardised.
    """
    # Snapshot the names so replacing variables below cannot disturb iteration.
    for variable in list(ds.data_vars):
        da = ds[variable]

        # Variables without a 'units' attribute are left untouched instead of
        # raising KeyError.
        if da.attrs.get('units') != 'mm':
            continue

        # Arithmetic drops attrs by default, so carry the originals across and
        # override only the units.
        ds[variable] = (da / 1000).assign_attrs({**da.attrs, 'units': 'm'})

    return ds
def detect_resolution(paths):
    """Attempt to detect the input resolution from a list of paths.

    A resolution is a number (optionally with a decimal part) immediately
    followed by ``km`` somewhere in the path, e.g. ``50km`` or ``12.5km``.

    Args:
        paths (list): List of file paths (str or path-like).

    Raises:
        ResolutionDetectionException: If no paths are supplied.
        ResolutionDetectionException: If a path contains no resolution, or more than one distinct resolution.
        ResolutionDetectionException: If there are inconsistent resolutions detected between paths.

    Returns:
        float: Resolution in km.
    """
    # Requiring digits before "km" stops stray text (e.g. "kmeans") from
    # counting as a candidate resolution.
    regex = re.compile(r'(\d+(?:\.\d+)?)km')
    res = None

    for path in paths:
        # Compare numerically and de-duplicate, so a path that mentions the
        # same resolution twice (directory and filename) is not ambiguous.
        matches = {float(match) for match in regex.findall(str(path))}

        if not matches:
            raise ResolutionDetectionException(f'Unable to detect resolution from {path}, no resolution found.')

        # Too many options to choose from
        if len(matches) > 1:
            raise ResolutionDetectionException(f'Unable to detect resolution from {path}, there are too many possibilities.')

        (detected,) = matches

        # First detection
        if res is None:
            res = detected

        # Already detected, but not the same
        elif res != detected:
            raise ResolutionDetectionException('Detected resolutions are inconsistent between supplied paths.')

    if res is None:
        raise ResolutionDetectionException('Unable to detect resolution, no paths were supplied.')

    # Made it this far, we have a detectable resolution
    return res
def _center_date(dt):
"""Centre the date for compatibility with CDO-processed data.
Args:
dt (object): Date object.
Returns:
same as called: Date object with day set to middle of the month.
"""
num_days = monthrange(dt.year, dt.month)[1]
return dt.replace(day=num_days // 2)
def postprocess_cordex(ds):
    """Apply the minor CORDEX-specific postprocessing.

    Every value of the ``time`` coordinate is moved to the middle of its month
    via ``_center_date``.

    Args:
        ds (xarray.Dataset): Data.

    Returns:
        xarray.Dataset: New dataset with the re-centred time coordinate.
    """
    times = ds['time'].to_pandas()
    midmonth = times.apply(_center_date).values
    return ds.assign_coords(time=midmonth)
def parse_domain_directive(directive):
    """Parse a domain directive (e.g. from the CLI).

    Args:
        directive (str): Domain directive of the form name,dx,lat_min,lat_max,lon_min,lon_max

    Returns:
        dict : Domain dictionary with the name as a string and all other fields as floats.

    Raises:
        ValueError : If the directive does not have exactly five numeric fields after the name.
    """
    parts = directive.split(',')
    name = parts[0]
    dx, lat_min, lat_max, lon_min, lon_max = [float(part) for part in parts[1:]]
    return {
        'name': name,
        'dx': dx,
        'lat_min': lat_min,
        'lat_max': lat_max,
        'lon_min': lon_min,
        'lon_max': lon_max,
    }
def parse_domain(directive):
    """Parse a domain directive.

    Domains are of the form: "name,dx,lat_min,lat_max,lon_min,lon_max"

    Args:
        directive (str) : Domain directive of the form name,dx,lat_min,lat_max,lon_min,lon_max.

    Returns:
        dict : Domain specification (name as str, all other fields as float).

    Raises:
        AssertionError : When the directive does not have exactly six comma-separated components.
        ValueError : When a numeric component cannot be parsed as a float.
    """
    segments = directive.split(',')

    # Explicit check (instead of `assert`) so it still runs under `python -O`;
    # AssertionError is kept so existing callers' exception handling still works.
    if len(segments) != 6:
        raise AssertionError(
            'Domain directive must have 6 comma-separated components '
            f'(name,dx,lat_min,lat_max,lon_min,lon_max), got {len(segments)}: {directive!r}'
        )

    name = segments[0]
    dx, lat_min, lat_max, lon_min, lon_max = [float(s) for s in segments[1:]]

    # Return a dictionary of the parsed information.
    return dict(
        name=name,
        dx=dx,
        lat_min=lat_min,
        lat_max=lat_max,
        lon_min=lon_min,
        lon_max=lon_max
    )
def load_domain_config():
    """Load the 'domains' configuration via ``axiom.config.load_config``.

    Returns:
        The object returned by ``load_config('domains')``. NOTE(review):
        presumably a mapping keyed by domain name (callers use ``.keys()``);
        confirm the concrete type.
    """
    domain_config = load_config('domains')
    return domain_config
def get_domain(key):
    """Build a Domain object for ``key`` from the internal domain configuration.

    Args:
        key (str): Domain key. NOTE(review): whether a parseable domain
            directive is also accepted depends on ``Domain.from_config``.

    Returns:
        axiom.domain.Domain : Domain object.
    """
    config = load_domain_config()
    return Domain.from_config(key, config)
def is_registered_domain(key):
    """Check whether ``key`` is one of the keys in the domain configuration.

    Args:
        key (str): Domain key.

    Returns:
        bool: True if registered, False otherwise.
    """
    registered = load_domain_config().keys()
    return key in registered
def load_processor(model_key, proc_type='pre'):
    """Load a pre-or-post processor for the model, if one exists.

    Looks up the function ``{proc_type}process_{model_key}`` in the module
    ``axiom.drs.processing.{model_key}``.

    Args:
        model_key (str): Model.
        proc_type (str, optional): Processor prefix, e.g. 'pre' or 'post'. Defaults to 'pre'.

    Returns:
        callable: The processor function if found. Otherwise a pass-through
        function that returns its first argument unchanged.
    """
    logger = au.get_logger(__name__)

    try:
        mod = importlib.import_module(f'axiom.drs.processing.{model_key}')
        processor = getattr(mod, f'{proc_type}process_{model_key}')
    except (ImportError, AttributeError) as ex:
        # Missing module or missing function means "no processor". Any other
        # error (e.g. a bug inside an existing processor module) propagates
        # instead of being silently reported as "not found".
        logger.warning(f'No {proc_type}processor found for {model_key}, returning empty function.')
        logger.debug(f'Processor lookup failed: {ex!r}')
        return lambda ds, *args, **kwargs: ds

    logger.info(f'Found {proc_type}processor for {model_key}')
    return processor
def load_preprocessor(model_key):
    """Load the preprocessor for a model (``load_processor`` with proc_type='pre').

    Args:
        model_key (str): Model

    Returns:
        callable: Whatever ``load_processor`` returns for this model.
    """
    return load_processor(model_key, proc_type='pre')
def load_postprocessor(model_key):
    """Load the postprocessor for a model (``load_processor`` with proc_type='post').

    Args:
        model_key (str): Model

    Returns:
        callable: Whatever ``load_processor`` returns for this model.
    """
    return load_processor(model_key, proc_type='post')
def interpolate_context(context):
    """Interpolate the context dictionary into itself, in place.

    Makes a single pass in insertion order: each value is converted to str and
    %-formatted against the context as it stands at that moment. Values
    interpolated earlier in the pass are therefore seen by later keys, but not
    the other way round.

    Args:
        context (dict): Context dictionary (mutated).

    Returns:
        dict : The same dictionary, with every value now a string.
    """
    logger = au.get_logger(__name__)
    for key in list(context):
        interpolated = str(context[key]) % context
        logger.debug(f'{key} = {interpolated}')
        context[key] = interpolated
    return context
def get_uninterpolated_placeholders(string):
    """List the ``%(name)s`` placeholders still present in a string.

    Args:
        string (string): String object.

    Returns:
        list : Sorted, de-duplicated placeholder names.
    """
    found = re.findall(r'%\(([a-zA-Z0-9_-]+)\)s', string)
    return sorted(set(found))
def is_time_invariant(ds):
    """Report whether the data has no 'time' coordinate.

    Args:
        ds (xarray.Dataset or xarray.DataArray): Data

    Return:
        bool : True if no 'time' coordinate detected, False otherwise
    """
    coordinate_names = ds.coords.keys()
    return 'time' not in coordinate_names
def is_error_recoverable(exception, recoverable_errors):
    """Decide whether an error is recoverable by matching its message against regexes.

    The message is the exception's ``message`` attribute if it has one,
    otherwise ``str(exception)``.

    Args:
        exception (Exception): Exception object.
        recoverable_errors (list): Regex patterns describing recoverable errors.

    Returns:
        bool : True if any pattern is found in the message, False otherwise.
    """
    message = exception.message if hasattr(exception, 'message') else str(exception)
    return any(re.search(pattern, message) for pattern in recoverable_errors)
def assemble_qsub_vars(**kwargs):
    """Build the comma-separated ``key=value`` list used by qsub's -v option (without -v).

    Args:
        **kwargs (dict): Keyword arguments to convert into qsub variables.

    Returns:
        str : String of qsub variables, in keyword order.
    """
    pairs = (f'{name}={value}' for name, value in kwargs.items())
    return ','.join(pairs)
def assemble_qsub_command(jobscript, directives, **context):
    """Assemble the qsub command.

    The directives are joined with spaces, %-interpolated against ``context``
    and placed between ``qsub`` and the jobscript.

    Args:
        jobscript (str) : Path to the jobscript.
        directives (list) : List of directives to interpolate.
        **context (dict) : Context dictionary to interpolate into the directives.

    Returns:
        str : The qsub command.
    """
    joined = ' '.join(directives)
    interpolated = joined % context
    return f'qsub {interpolated} {jobscript}'
def generate_user_config():
    """Generate the user ~/.axiom directory with all installed data files.

    If ~/.axiom already exists, it is first moved to a timestamped backup
    (``~/.axiom_<YYYYmmddHHMMSS>``). A fresh directory is then created and
    every ``data/*.json`` file installed with the ``axiom`` package is copied
    into it.
    """
    # importlib.resources is a submodule: a bare `import importlib` does not
    # guarantee it has been loaded, so import it explicitly here.
    from importlib import resources as importlib_resources

    logger = au.get_logger(__name__)

    # Work out the users home directory (os independent)
    axiom_dir = os.path.join(str(Path.home()), '.axiom')

    # Does it already exist? Back it up with a timestamp of the current time.
    if os.path.exists(axiom_dir):
        timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
        backup_dir = f'{axiom_dir}_{timestamp}'
        print(f'Backing up existing .axiom directory to {backup_dir}')
        shutil.move(axiom_dir, backup_dir)

    # Create the directory and copy the files over.
    logger.info(f'Creating new .axiom directory at {axiom_dir}')
    os.makedirs(axiom_dir, exist_ok=True)

    config_files = au.auto_glob(os.path.join(
        importlib_resources.files('axiom'),
        'data/*.json'
    ))

    for src in config_files:
        dst = os.path.join(axiom_dir, os.path.basename(src))
        logger.info(f'Copying {src} to {dst}')
        shutil.copy(src, dst)

    logger.info('User configuration generated successfully.')
def filter_by_variable_name(filepaths, variable):
    """Keep only the filepaths whose filename matches the configured variable regex.

    The 'drs' config's ``filename_filtering`` section controls this. When its
    ``variable`` flag is falsy, the input is returned unchanged. Otherwise
    ``variable_regex`` is %-interpolated with the variable name and searched
    for in each file's basename.

    Args:
        filepaths (list): List of filepaths.
        variable (str): Variable name to interpolate into the regex.

    Returns:
        list : Filepaths whose basename matches (or the original input when filtering is disabled).
    """
    filtering = load_config('drs')['filename_filtering']

    if not filtering['variable']:
        return filepaths

    pattern = filtering['variable_regex'] % {'variable': variable}
    return [
        path for path in filepaths
        if re.search(pattern, os.path.basename(path))
    ]
def get_start_and_end_dates(year, output_frequency):
    """Get the start and end date strings for the interpolation context.

    Monthly frequencies (ending in 'M') give YYYYMM strings. Any other
    frequency gives YYYYMMDD strings.

    Args:
        year (int): Current year
        output_frequency (str): Frequency.

    Returns:
        tuple: (start_date, end_date) as strings.
    """
    monthly = output_frequency[-1] == 'M'
    if monthly:
        return f'{year}01', f'{year}12'
    return f'{year}0101', f'{year}1231'
def generate_years_list(start_year, end_year):
    """Generate the decade start years to process, from start_year up to end_year inclusive.

    Args:
        start_year (int): Start year.
        end_year (int): End year.

    Returns:
        range : Years to process, stepping by 10.
    """
    decade = 10
    stop = end_year + 1  # range's stop is exclusive; include end_year itself
    return range(start_year, stop, decade)