This notebook conducts time-series forecasting of vegetation condition (NDVI) using SARIMAX, a variation on autoregressive-moving-average (ARMA) models which includes an integrated (I) component to difference the timeseries so it becomes stationary, a seasonal (S) component, and has the capacity to consider exogenous (X) variables.
In this example, we will conduct a forecast on a univariate NDVI timeseries. That is, our forecast will be built on temporal patterns in NDVI. Conversely, multivariate approaches can account for influences of variables such as soil moisture and rainfall.
In this notebook, we generate an NDVI timeseries from Sentinel-2, then use it to develop a forecasting algorithm.
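As a rough illustration of what the integrated (I) and seasonal differencing steps do, the sketch below differences a made-up series by hand. The `difference` helper and the synthetic series (a linear trend plus a pattern repeating every 6 steps) are assumptions for illustration only; SARIMAX performs this differencing internally.

```python
def difference(values, lag=1):
    """Return values[t] - values[t - lag] for every t >= lag."""
    return [values[t] - values[t - lag] for t in range(lag, len(values))]


# Hypothetical series: a linear trend plus a pattern that repeats every 6 steps
season = [0, 3, 5, 4, 2, 1]
series = [2 * t + season[t % 6] for t in range(24)]

# First difference (d = 1) removes the linear trend
first_diff = difference(series, lag=1)

# Seasonal difference (D = 1, s = 6) removes the repeating 6-step pattern
seasonal_diff = difference(first_diff, lag=6)

print(seasonal_diff)
```

After both steps the series has no remaining trend or seasonal pattern, which is the sense in which differencing makes a timeseries stationary.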
The following steps are taken:
Import Python packages that are used for the analysis.
Important note: SciPy has been updated and has some incompatibilities with older versions of statsmodels. If the package-loading cell below returns an error, try running pip install statsmodels or pip install statsmodels --upgrade in a code cell, then load the packages again.
%matplotlib inline
from itertools import product
import datacube
import numpy as np
import pandas as pd
import statsmodels.api as sm
import xarray as xr
from datacube import Datacube
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.datahandling import load_ard
from deafrica_tools.plotting import display_map
from matplotlib import pyplot as plt
from statsmodels.tools.eval_measures import rmse
from tqdm.notebook import tqdm
Dask can be used to better manage memory use and conduct the analysis in parallel. For an introduction to using Dask with Digital Earth Africa, see the Dask notebook.
Note: We recommend opening the Dask processing window to view the different computations that are being executed; to do this, see the Dask dashboard in DE Africa section of the Dask notebook.
To use Dask, set up the local computing cluster using the cell below.
create_local_dask_cluster()
dc = datacube.Datacube(app="NDVI_forecast")
lat, lon: The central latitude and longitude to analyse. In this example we'll use an agricultural area in Ethiopia.
buffer: The number of degrees to load around the central latitude and longitude. For reasonable loading times, set this to 0.1 or lower.
products: The satellite data to load. In this example we will use Sentinel-2.
time_range: The date range to analyse. The longer the date range, the more data the model has to derive patterns in the NDVI timeseries.
freq: The frequency we want to resample the time-series to, e.g. for monthly time steps use '1M', for fortnightly use '2W'.
forecast_length: The length of time beyond the latest observation in the dataset that we want the model to forecast, expressed in units of the resample frequency freq. A longer forecast_length means greater forecast uncertainty.
resolution: The pixel resolution (in metres) to use for loading Sentinel-2 data.
dask_chunks: How to chunk the datasets to work with Dask.
# Define the analysis region (Lat-Lon box)
lat, lon = 8.5593, 40.6975
buffer = 0.04
# the satellite product to load
products = "s2_l2a"
# Define the time window for defining the model
time_range = ("2017-01-01", "2022-01")
# resample frequency
freq = "1M"
# number of time-steps to forecast (in units of `freq`)
forecast_length = 6
# resolution of Sentinel-2 pixels
resolution = (-20, 20)
# dask chunk sizes
dask_chunks = {"x": 500, "y": 500, "time": -1}
lon = (lon - buffer, lon + buffer)
lat = (lat - buffer, lat + buffer)
display_map(lon, lat)
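To make the role of buffer concrete, here is a minimal sketch of how a central point is expanded into a latitude and longitude range. The `bounding_box` helper is hypothetical, and the figure of roughly 111 km per degree of latitude is an approximation used only to give a sense of scale.

```python
def bounding_box(lat, lon, buffer):
    """Return ((lat_min, lat_max), (lon_min, lon_max)) around a central point."""
    return (lat - buffer, lat + buffer), (lon - buffer, lon + buffer)


# Same central point and buffer as the default example
lat_range, lon_range = bounding_box(8.5593, 40.6975, 0.04)

# Rough north-south extent, assuming about 111 km per degree of latitude
height_km = (lat_range[1] - lat_range[0]) * 111

print(lat_range, lon_range, round(height_km, 2))
```

Because the box spans twice the buffer in each direction, doubling buffer roughly quadruples the area to load, which is why small values keep loading times reasonable.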
Load the Sentinel-2 data using the parameters we defined above.
# set up a datacube query object
query = {
"x": lon,
"y": lat,
"time": time_range,
"measurements": ["red", "nir"],
"output_crs": "EPSG:6933",
"resolution": resolution,
"resampling": {"fmask": "nearest", "*": "bilinear"},
}
# load the satellite data
ds = load_ard(dc=dc, dask_chunks=dask_chunks, products=products, **query)
print(ds)
Using pixel quality parameters for Sentinel 2
Finding datasets
    s2_l2a
Applying pixel quality/cloud mask
Returning 702 time steps as a dask array
<xarray.Dataset>
Dimensions:      (time: 702, y: 505, x: 387)
Coordinates:
  * time         (time) datetime64[ns] 2017-01-06T07:42:19 ... 2022-01-30T07:...
  * y            (y) float64 1.093e+06 1.093e+06 ... 1.083e+06 1.083e+06
  * x            (x) float64 3.923e+06 3.923e+06 ... 3.931e+06 3.931e+06
    spatial_ref  int32 6933
Data variables:
    red          (time, y, x) float32 dask.array<chunksize=(702, 500, 387), meta=np.ndarray>
    nir          (time, y, x) float32 dask.array<chunksize=(702, 500, 387), meta=np.ndarray>
Attributes:
    crs:           EPSG:6933
    grid_mapping:  spatial_ref
Load the cropland mask over the region of interest. The default region we're analysing is in Ethiopia, so we need to load either the crop_mask product, which covers the entire African continent, or the crop_mask_eastern product, which covers the countries of Ethiopia, Kenya, Tanzania, Rwanda, and Burundi. If you change the analysis region from the default one, you may need to load a different crop mask - see the docs page to find out more.
cm = dc.load(
product="crop_mask",
time=("2019"),
measurements="filtered",
resampling="nearest",
like=ds.geobox,
).filtered.squeeze()
# convert the missing values (255) to NaN
cm = cm.where(cm != 255)
cm.plot.imshow(add_colorbar=False, figsize=(6, 6))
plt.title("Cropland Extent");
Now we will use the cropland map to mask the Sentinel-2 data so that only pixels classified as cropland are kept.
ds = ds.where(cm == 1)
After calculating NDVI, we will smooth and interpolate the data to ensure we are working with a consistent time-series. This is a very important step in the workflow and there are many ways to smooth, interpolate, gap-fill, remove outliers, or curve-fit the data to ensure a consistent time-series. If not using the default example, you may have to define additional methods to those used here.
To do this we take two steps:
# calculate NDVI
ndvi = calculate_indices(ds, "NDVI", drop=True, satellite_mission="s2")
Dropping bands ['red', 'nir']
# resample and smooth
window = 4
ndvi = (
ndvi.resample(time=freq)
.mean()
.rolling(time=window, min_periods=1, center=True)
.mean()
)
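The two steps above (resampling to a regular frequency, then smoothing with a centred rolling mean) can be sketched on a small pandas series. The data here are made up, and the sketch uses pandas rather than xarray and the "MS" (month start) alias rather than the notebook's "1M", so bin labels will differ from the notebook's output.

```python
import pandas as pd

# Hypothetical irregular observations: one every 5 days, value = month number
dates = pd.date_range("2021-01-01", "2021-04-30", freq="5D")
obs = pd.Series(dates.month.to_numpy(dtype=float), index=dates)

# Step 1: aggregate to one value per month ("MS" labels each bin by month start)
monthly = obs.resample("MS").mean()

# Step 2: centred rolling mean; min_periods=1 stops the ends becoming NaN
window = 4
smoothed = monthly.rolling(window=window, min_periods=1, center=True).mean()

print(monthly)
print(smoothed)
```

Setting min_periods=1 means the first and last values are averaged over fewer observations than the rest, so the ends of the smoothed series are less heavily smoothed.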
In this example, we're generating a forecast on a simple 1D timeseries. This time-series represents the spatially averaged NDVI at each time-step in the series.
In this step, all the calculations above are triggered and the dataset is brought into memory so this step can take a few minutes to complete.
ndvi = ndvi.mean(["x", "y"])
ndvi = ndvi.NDVI.compute()
ndvi.plot(figsize=(11, 5), linestyle="dashed", marker="o")
plt.title("NDVI")
plt.ylim(0, ndvi.max().values + 0.05);
Cross-validation is a common method for evaluating model performance. It involves dividing data into a training set, on which the model is trained, and a test (or validation) set, to which the model is applied to produce predictions that are compared against actual values (that weren't used in model training).
ndvi = ndvi.drop_vars("spatial_ref").to_dataframe()
# hold back the last ten observations as test data
n_test = 10
train_data = ndvi["NDVI"][:-n_test]
test_data = ndvi["NDVI"][-n_test:]
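The split above is a single hold-out at the end of the series. As a minimal sketch of the same idea on a plain Python list, the hypothetical `holdout_split` helper below keeps the observations in time order and never shuffles them, which matters for timeseries because the model must not see values from the period it is asked to predict.

```python
def holdout_split(values, n_test):
    """Split an ordered sequence into (train, test) without shuffling."""
    if not 0 < n_test < len(values):
        raise ValueError("n_test must be between 1 and len(values) - 1")
    return values[:-n_test], values[-n_test:]


# Hypothetical series of 61 time steps, labelled 0..60
series = list(range(61))
train, test = holdout_split(series, n_test=10)

print(len(train), len(test))
```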
SARIMAX models are fitted with parameters for both the trend and seasonal components of the timeseries. The parameters can be defined as:
In the cell below, initial values and a range are given for the parameters above. Using range(0, 3) means the values 0, 1, and 2 are iterated through for each of p, q, P, and Q, while d and D are fixed at 1. This means that there are $3^2 \times 3^2 = 81$ possible combinations.
# Set initial values and some bounds
p = range(0, 3)
d = 1
q = range(0, 3)
P = range(0, 3)
D = 1
Q = range(0, 3)
s = 6
# Create a list with all possible combinations of parameters
parameters = product(p, q, P, Q)
parameters_list = list(parameters)
print("Number of iterations to run:", len(parameters_list))
# Train many SARIMA models to find the best set of parameters
def optimize_SARIMA(parameters_list, d, D, s):
    """
    Return a dataframe of parameters with the corresponding AIC and RMSE,
    sorted by ascending RMSE.

    parameters_list - list with (p, q, P, Q) tuples
    d - integration order
    D - seasonal integration order
    s - length of season
    """
    import warnings

    # silence the convergence warnings raised while fitting
    warnings.filterwarnings("ignore")

    results = []

    for param in tqdm(parameters_list):
        try:
            model = sm.tsa.statespace.SARIMAX(
                train_data,
                order=(param[0], d, param[1]),
                seasonal_order=(param[2], D, param[3], s),
            ).fit(disp=-1)
            pred = model.predict(start=len(train_data), end=(len(ndvi) - 1))
            error = rmse(pred, test_data)
        except Exception:
            # skip parameter combinations that fail to fit
            continue

        results.append([param, model.aic, error])

    result_table = pd.DataFrame(results, columns=["parameters", "aic", "error"])

    # Sort in ascending order, lower RMSE is better
    result_table = result_table.sort_values(by="error", ascending=True).reset_index(
        drop=True
    )

    return result_table
Number of iterations to run: 81
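As a small sketch of how each (p, q, P, Q) tuple in the grid is turned into the order and seasonal_order arguments that SARIMAX expects, consider the snippet below. The `to_orders` helper is hypothetical and only mirrors the indexing used inside optimize_SARIMA.

```python
from itertools import product

# d, D and s are held fixed, as in the cell above
d, D, s = 1, 1, 6

# every (p, q, P, Q) combination with each value in {0, 1, 2}
grid = list(product(range(3), repeat=4))


def to_orders(param):
    """Map a (p, q, P, Q) tuple to (order, seasonal_order) for SARIMAX."""
    p, q, P, Q = param
    return (p, d, q), (P, D, Q, s)


print(len(grid))
print(to_orders(grid[-1]))
```

Widening any of the ranges grows the grid multiplicatively, so range(0, 4) for all four parameters would already mean 256 models to fit.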
Now we will run the function above for every iteration of parameters we have defined. Depending on the number of iterations, this can take a few minutes to run. A progress bar is printed below.
# run the function above
result_table = optimize_SARIMA(parameters_list, d, D, s)
0%| | 0/81 [00:00<?, ?it/s]
The root-mean-square error (RMSE) is a common metric used to evaluate model or forecast performance. It is the standard deviation of residuals (difference between forecast and actual value) expressed in units of the variable of interest e.g. NDVI. We can calculate RMSE of our forecast because we withheld some observations as test or validation data.
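For reference, RMSE can be computed by hand in a few lines. This is a minimal sketch with made-up forecast and observed values; `rmse_manual` is a hypothetical name chosen to avoid clashing with the rmse function imported from statsmodels.

```python
import math


def rmse_manual(predicted, observed):
    """Root-mean-square error between two equal-length sequences."""
    if len(predicted) != len(observed):
        raise ValueError("predicted and observed must have the same length")
    squared = [(p - o) ** 2 for p, o in zip(predicted, observed)]
    return math.sqrt(sum(squared) / len(squared))


# Hypothetical NDVI forecast and the values actually observed
forecast = [0.50, 0.55, 0.60]
actual = [0.52, 0.55, 0.57]
error = rmse_manual(forecast, actual)

print(round(error, 4))
```

Because the errors are squared before averaging, a single large miss raises the RMSE more than several small ones.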
We can also use either the Akaike information criterion (AIC) or Bayesian information criterion (BIC) for model selection. Both these criteria aim to optimise the trade-off between goodness of fit and model simplicity. We are aiming to find the model that can explain the most variation in the timeseries with the least complexity, as added complexity may lead to overfitting. The BIC penalises additional parameters (greater complexity) more than the AIC.
There are different schools of thought on which criterion to use. A general rule of thumb is that the BIC should be used for inference and interpretation whereas the AIC should be used for prediction. As our goal is prediction (forecasting), we could select the model with the lowest AIC, though this approach is often reserved for when there is no test data available for cross-validation.
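The difference between the two criteria is easiest to see from their standard definitions, AIC = 2k - 2 ln(L) and BIC = k ln(n) - 2 ln(L), where k is the number of estimated parameters, n the number of observations, and L the maximised likelihood. The sketch below uses made-up values for the log-likelihood, k, and n; the exact parameter count statsmodels uses for a fitted SARIMAX model is not shown here.

```python
import math


def aic(log_likelihood, k):
    """Akaike information criterion: 2k - 2 ln(L)."""
    return 2 * k - 2 * log_likelihood


def bic(log_likelihood, k, n):
    """Bayesian information criterion: k ln(n) - 2 ln(L)."""
    return k * math.log(n) - 2 * log_likelihood


# Hypothetical fitted model: log-likelihood 90, 4 parameters, 51 observations
model_aic = aic(90.0, k=4)
model_bic = bic(90.0, k=4, n=51)

print(model_aic, round(model_bic, 3))
```

Each extra parameter costs 2 under the AIC but ln(n) under the BIC, so the BIC penalty is larger whenever n is 8 or more.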
The cell below presents the top 15 models based on AIC and the RMSE on the cross-validation.
# Sort table by the lowest RMSE, using the lowest AIC (Akaike information criterion) to break ties
result_table = result_table.sort_values(["error", "aic"])
print(result_table[0:15])
      parameters         aic     error
0   (0, 0, 1, 2) -167.730806  0.018184
1   (0, 1, 2, 2) -187.450473  0.018201
2   (0, 0, 2, 1) -163.183898  0.018655
3   (0, 1, 1, 1) -186.749586  0.019855
4   (2, 2, 0, 2) -203.806815  0.020542
5   (0, 0, 2, 2) -166.948974  0.021318
6   (2, 2, 1, 2) -201.434439  0.023771
7   (2, 2, 0, 1) -205.304806  0.023861
8   (0, 1, 1, 2) -188.009648  0.024562
9   (0, 0, 1, 1) -163.806008  0.024623
10  (0, 1, 2, 1) -185.070702  0.026774
11  (2, 1, 0, 0) -193.296201  0.027527
12  (1, 1, 2, 2) -191.548041  0.028136
13  (0, 2, 1, 2) -190.940525  0.029760
14  (0, 2, 2, 2) -188.050115  0.033009
In the cell below, we will select a model from the list above. In this case we've selected model 0 as it has the lowest RMSE, though you can select any model by setting its index number with the model_sel_index parameter.
# selected model
model_sel_index = 0
# store parameters from selected model
p, q, P, Q = result_table.iloc[model_sel_index].parameters
print(result_table.iloc[model_sel_index])
# fit the model with the parameters identified above
best_model = sm.tsa.statespace.SARIMAX(
train_data, order=(p, d, q), seasonal_order=(P, D, Q, s)
).fit(disp=-1)
parameters    (0, 0, 1, 2)
aic            -167.730806
error             0.018184
Name: 0, dtype: object
There are some typical plots we can use to evaluate our model.
Standardised residuals (top-left): The standardised residuals are plotted against x (time) values. This allows us to check that variance (distance of residuals from 0) is constant across time values. There should be no obvious patterns.
Histogram and estimated density (top-right): A kernel density estimation (KDE) is an estimated probability density function fitted on the actual distribution (histogram) of standardised residuals. A normal distribution (N(0,1)) is shown for reference. This plot shows that the distribution of our standardised residuals is close to normally distributed.
Normal quantile-quantile (Q-Q) plot (bottom-left): This plot shows 'expected' or 'theoretical' quantiles drawn from a normal distribution on the x-axis against quantiles taken from the sample of residuals on the y-axis. If the observations in blue match the 1:1 line in red, then we can conclude that our residuals are normally distributed.
Correlogram (bottom-right): The correlations for lags greater than 0 should not be statistically significant. That is, they should not be outside the blue ribbon.
Note: The Q-Q plot and correlogram generated for model 0 show there is some pattern in the residuals. That is, there is remaining variation in the data which the model has not accounted for. You could experiment with different parameter values or model selection in the prior steps to see if this can be addressed.
fig = plt.figure(figsize=(16, 9))
fig = best_model.plot_diagnostics(lags=25, fig=fig)
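To show what the correlogram is measuring, the sketch below computes a sample autocorrelation by hand and compares it with the common white-noise approximation of the 95% significance bound, 1.96 / sqrt(n). The residual values are made up, and the band drawn by statsmodels may be calculated differently, so treat this as an illustration of the idea rather than a reproduction of the plot.

```python
import math


def autocorrelation(values, lag):
    """Sample autocorrelation of a sequence at the given lag."""
    n = len(values)
    mean = sum(values) / n
    denom = sum((v - mean) ** 2 for v in values)
    numer = sum(
        (values[t] - mean) * (values[t - lag] - mean) for t in range(lag, n)
    )
    return numer / denom


def significance_bound(n, z=1.96):
    """Approximate 95% bound for the autocorrelation of white noise."""
    return z / math.sqrt(n)


# Hypothetical residuals that alternate in sign, i.e. a strong leftover pattern
residuals = [1.0, -1.0] * 10
r1 = autocorrelation(residuals, lag=1)
bound = significance_bound(len(residuals))

print(round(r1, 2), round(bound, 3))
```

Here the lag-1 autocorrelation falls well outside the bound, which is the kind of result that would suggest the model has left structure in the residuals.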
We saved the last 10 observations as test data above. Now we can use our model to predict NDVI for those time-steps and compare those predictions with actual values. We can do this visually in the graph below and also quantify the error with the root-mean-square error (RMSE).
pred = best_model.predict(start=len(train_data), end=(len(ndvi) - 1))
plt.figure(figsize=(11, 5))
pred.plot(label="forecast", linestyle="dashed", marker="o")
train_data.plot(label="training data", linestyle="dashed", marker="o")
test_data.plot(label="test data", linestyle="dashed", marker="o")
plt.legend(loc="upper left");
To forecast NDVI into the future, we'll run a model on the entire time series so we can include the latest observations. We can see that the forecast uncertainty, expressed as the 95% confidence interval, increases with time.
final_model = sm.tsa.statespace.SARIMAX(
ndvi, order=(p, d, q), seasonal_order=(P, D, Q, s)
).fit(disp=-1)
yhat = final_model.get_forecast(forecast_length);
fig, ax = plt.subplots(1, 1, figsize=(11, 5))
yhat.predicted_mean.plot(label="NDVI forecast", ax=ax, linestyle="dashed", marker="o")
ax.fill_between(
yhat.predicted_mean.index,
yhat.conf_int()["lower NDVI"],
yhat.conf_int()["upper NDVI"],
alpha=0.2,
)
ndvi[-36:].plot(label="Observations", ax=ax, linestyle="dashed", marker="o")
plt.legend(loc="upper left");
Our forecast looks reasonable in the context of the timeseries above.
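The widening of the confidence interval with forecast horizon can be illustrated with the simplest possible case, a random walk, where the forecast variance after h steps is h times the one-step variance. This is a sketch under that assumption only; the intervals from the fitted SARIMAX model depend on its estimated parameters and will not follow this exact formula. The `random_walk_interval` helper and its input values are hypothetical.

```python
import math


def random_walk_interval(last_value, sigma, steps, z=1.96):
    """Approximate 95% forecast intervals for a random walk, one per step."""
    intervals = []
    for h in range(1, steps + 1):
        # forecast standard deviation grows with the square root of the horizon
        half_width = z * sigma * math.sqrt(h)
        intervals.append((last_value - half_width, last_value + half_width))
    return intervals


# Hypothetical last observed NDVI of 0.5 and one-step standard deviation of 0.02
intervals = random_walk_interval(last_value=0.5, sigma=0.02, steps=6)
widths = [upper - lower for lower, upper in intervals]

print([round(w, 3) for w in widths])
```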
License: The code in this notebook is licensed under the Apache License, Version 2.0. Digital Earth Australia data is licensed under the Creative Commons by Attribution 4.0 license.
Contact: If you need assistance, please post a question on the Open Data Cube Slack channel or on the GIS Stack Exchange using the open-data-cube tag (you can view previously asked questions here).
If you would like to report an issue with this notebook, you can file one on Github.
Last modified: January 2022
Compatible datacube version:
print(datacube.__version__)
1.8.6
Last Tested:
from datetime import datetime
datetime.today().strftime("%Y-%m-%d")
'2022-07-07'