Estimating Chl-a from Remote Sensing - a CDE Machine Learning Approach

In this notebook, we explore state-of-the-art machine learning approaches to the task of estimating the concentration of chlorophyll-a (chl-a) from remote sensing data (satellites). The goal of this notebook is to demonstrate the possibilities of publicly available datasets in solving certain real-world environmental problems, and the Machine Learning opportunities that these problems invite.

Olivier Graffeuille
NumPy, pandas, scikit-learn, PyTorch
We thank the various regional councils for their help gathering this important environmental ground-truth data, and the European Space Agency for their satellite data.

Estimating Chl-a from Remote Sensing data with Machine Learning

The problem

Harmful Algal Blooms

Algal blooms are a rapid build-up of algae in an aquatic system. They can be harmful to local ecosystems, aquaculture, and even human health. As a result, councils monitor the concentration of algae in important lakes around the country. This is done by manually collecting water samples, often by boat, and analysing them in a lab, a process which is expensive and slow and results in poor spatial and temporal resolution.

Satellite Data

To mitigate these issues, environmental scientists have begun using satellite data to estimate the concentration of algae. Rather than estimating the amount of algae directly, they estimate the concentration of chlorophyll-a, the green pigment in algae, as a proxy, because it has a more consistent optical signal.

In this work, we use the Ocean and Land Colour Instrument (OLCI) sensor of the Sentinel-3 satellite. This satellite has poor spatial resolution (~300 m pixels) but high spectral resolution. To understand this, we need to know that satellite images are multispectral: rather than only measuring overall intensity or red/green/blue light, they measure light in a number of defined wavelength ranges called bands. OLCI has 21 bands, specifically placed to effectively estimate aquatic parameters such as chl-a.

The Task

This multispectral data will be the input (features) to our model, and the output (target) will be the ground-truth (called in-situ) concentration of chl-a measured by environmental scientists. More specifically, we will predict the chl-a concentration from the multispectral data of the single satellite pixel that corresponds to the location where the in-situ sample was collected. We don't consider adjacent pixels because they don't have labels and can have vastly different concentrations of algae.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn
```

Step 0 - The data

The satellite data

The Sentinel-3 OLCI data can be accessed publicly at

A TAIAO page explaining this dataset can be visited at

The in-situ data

The in-situ data is available to the public and can be obtained by corresponding with the various regional councils, although the process of obtaining it tends to differ from council to council.

The data processing

To avoid a lengthy data manipulation process, in this notebook we will use data that has already been processed. The following steps have been taken:

  • Collecting satellite data - mass downloading of OLCI data over New Zealand, including atmospheric correction using the Rayleigh correction algorithm in SNAP
  • Cleaning satellite data - filtering out cloudy pixels, pixels which overlap with land surfaces, etc.
  • Collecting in-situ data - communication with various New Zealand regional councils
  • Matching satellite and in-situ data - matching the most appropriate pixel to each in-situ measurement, taking into account spatial distance, temporal distance, etc.
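The matching step above can be sketched roughly as follows. This is a minimal illustration, not the pipeline that produced this dataset: the DataFrame layout (`lat`, `lon`, `time`, boolean `cloud` and `land` flags), the 0.5 km / 24 h thresholds and the scoring rule that trades off distance against time gap are all hypothetical choices.

```python
import numpy as np
import pandas as pd

EARTH_RADIUS_KM = 6371.0


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between points given in degrees (broadcasts over arrays)."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    a = np.sin((lat2 - lat1) / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))


def match_pixels(insitu, pixels, max_km=0.5, max_hours=24.0):
    """Hypothetical matchup: for each in-situ sample, return the index of the best valid pixel, or -1."""
    # Cleaning step: drop cloudy pixels and pixels overlapping land
    valid = pixels[~pixels["cloud"] & ~pixels["land"]]
    matches = []
    for _, sample in insitu.iterrows():
        dist = haversine_km(sample["lat"], sample["lon"], valid["lat"].to_numpy(), valid["lon"].to_numpy())
        gap_h = np.abs((valid["time"] - sample["time"]).dt.total_seconds().to_numpy()) / 3600.0
        ok = (dist <= max_km) & (gap_h <= max_hours)
        if not ok.any():
            matches.append(-1)
            continue
        # Assumed score: spatial distance and time gap, each relative to its threshold
        score = np.where(ok, dist / max_km + gap_h / max_hours, np.inf)
        matches.append(int(valid.index[np.argmin(score)]))
    return matches


# Toy example: pixel 0 is cloudy, pixel 2 is ~11 km away, so pixel 1 is the match
insitu = pd.DataFrame({"lat": [-38.80], "lon": [175.90],
                       "time": pd.to_datetime(["2020-01-10 10:00"])})
pixels = pd.DataFrame({
    "lat": [-38.800, -38.801, -38.900],
    "lon": [175.900, 175.901, 175.900],
    "time": pd.to_datetime(["2020-01-10 09:30", "2020-01-10 22:00", "2020-01-10 10:00"]),
    "cloud": [True, False, False],
    "land": [False, False, False],
})
print(match_pixels(insitu, pixels))
```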

This is the resulting dataset:

```python
data_all = pd.read_csv(r"data\NZ_chl\data.csv")
display(data_all)
```

346 rows × 18 columns

Each row corresponds to one ground-truth data point. The multispectral satellite data from the pixel matched to this point are in the Rayleigh-corrected reflectance ("rBRR") columns, and the ground-truth chl-a value is in the "Chl-a" column. Note that the multispectral data has been normalised (as is convention in remote sensing), but that the pre-normalised spectral sum is recorded in the "sum" column to avoid loss of information. We also note that the Chl-a values are log-transformed (natural log), because chl-a concentrations are naturally distributed on an exponential scale.
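Why keep the "sum" column? A minimal sketch, assuming the normalisation divides each spectrum by the sum of its bands (the notebook doesn't specify the exact scheme; sum-normalisation is only what the "sum" column suggests). Two hypothetical spectra with the same shape but different brightness become identical after normalisation, and only the stored sum tells them apart. The helper `normalise_spectra` and all values are illustrative.

```python
import numpy as np


def normalise_spectra(reflectance):
    """Divide each spectrum (row) by its band sum; return normalised bands and the sums."""
    reflectance = np.asarray(reflectance, dtype=float)
    spectral_sum = reflectance.sum(axis=1, keepdims=True)
    return reflectance / spectral_sum, spectral_sum


# Two hypothetical 3-band spectra: same shape, the second twice as bright
raw = np.array([[0.02, 0.03, 0.05],
                [0.04, 0.06, 0.10]])
normed, spectral_sum = normalise_spectra(raw)

# Feature matrix: normalised bands plus the sum column, so brightness information is kept
features = np.hstack([normed, spectral_sum])

# Target: natural log of hypothetical chl-a concentrations
chl_a = np.array([2.0, 20.0])
target = np.log(chl_a)

print(normed)        # both rows are identical after normalisation
print(spectral_sum)  # the sums still distinguish them
```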

Step 1 - Building a baseline model

Let's try to build a simple random forest (RF) model to estimate Chl-a. Note that in this notebook we build only one model per approach, avoiding repeated sampling or cross-validation for simplicity.

```python
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error

seed = 0
np.random.seed(seed)

X = data_all.drop("Chl-a", axis=1).values
y = data_all["Chl-a"].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

print(X.shape, y.shape)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
```

(346, 17) (346,)
(276, 17) (276,)
(70, 17) (70,)

```python
# Create RF model object
RF_model = RandomForestRegressor()

# Train model
RF_model.fit(X_train, y_train)

# Make prediction on test data
y_pred = RF_model.predict(X_test)

# Evaluate performance of model
RF_r2 = r2_score(y_test, y_pred)
RF_mse = mean_squared_error(y_test, y_pred)
print("Random Forest Performance - MSE: {:.4f}, R2: {:.4f}".format(RF_mse, RF_r2))

# Plot prediction vs. real chlorophyll-a values
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
x_y_line = [np.min(y_test), np.max(y_test)]
plt.plot(x_y_line, x_y_line, "k--")
plt.scatter(y_test, y_pred)
plt.xlabel("True ln(Chl-a [g/m^2])")
plt.ylabel("Predicted ln(Chl-a [g/m^2])")
plt.title("Random Forest Predictions")
```

Random Forest Performance - MSE: 0.3623, R2: 0.7447

[Figure: Random Forest Predictions, predicted vs. true ln(Chl-a) on the test set]

We can see that this simple model learns the relationship between satellite data and in-situ chl-a reasonably well (test R2 of about 0.74). This is very exciting, as it indicates that the satellite data we've matched to the in-situ data is informative about algae concentration, so now let's see how we can develop a better ML model for this task. This random forest model will act as a baseline against which to evaluate future models.

Step 2 - Mixture Density Networks

Conditional Density Estimation

Recently, researchers have advocated the use of Conditional Density Estimation (CDE) for the task of estimating Chl-a. In CDE, instead of designing a model that predicts a single value of a continuous variable (as we do in regression tasks), we design a model that outputs an estimate of the full probability density of the target variable, conditioned on the input. CDE has the advantages of being able to model uncertainty, and of being able to effectively model tasks where a single input may have multiple correct labels.

CDE makes sense for this task, because we're modelling an inverse problem - we're predicting the underlying physical variables from their observed signals. A consequence of this is that it's possible for multiple combinations of things in the water to produce the same optical signal, and hence given some multispectral data, multiple possible concentrations of Chl-a could have produced it. The case for CDE is made in "Seamless retrievals of chlorophyll-a from Sentinel-2 (MSI) and Sentinel-3 (OLCI) in inland and coastal waters: A machine-learning approach" by Pahlevan et al., 2020.

If this isn't intuitive, imagine a bucket of water. If we add some drops of green paint to the bucket, it's pretty easy to estimate the concentration of green paint in the bucket. However, if we then add some yellow, red, maybe even turquoise paint to the bucket, making a brown mix, it becomes hard (or even impossible) to accurately measure how much green paint is in the bucket.
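To see why this matters for ordinary regression, here is a small synthetic sketch (hypothetical data, not chl-a): every input x arises from one of two very different targets, y = +x or y = -x. A model trained with squared error learns the conditional mean E[y | x], which here sits near 0, a value no actual target ever takes. A CDE model can instead put probability mass on both branches.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical one-to-many data: each x comes from y = +x or y = -x with equal probability
x = rng.uniform(0.5, 1.0, size=2000)
y = rng.choice([-1.0, 1.0], size=x.size) * x

# Within each x bin, the squared-error-optimal constant prediction is the bin mean (approx. E[y | x])
bin_starts = np.linspace(0.5, 0.9, 5)
bin_means = np.array([y[(x >= lo) & (x < lo + 0.1)].mean() for lo in bin_starts])

print("binned conditional means:", np.round(bin_means, 2))
print("smallest |y| actually observed:", round(float(np.abs(y).min()), 2))
```

Every binned mean lands close to 0, while every real target has magnitude at least 0.5, so the regression answer is "between" the true solutions rather than at one of them.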

Mixture Density Networks

So, CDE is a sensible approach to modelling Chl-a. The most common approach to CDE is the Mixture Density Network (MDN), a type of neural network whose output layer parameterises a mixture distribution. This blog explains MDNs in great depth and clarity:
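Written out, an MDN with K Gaussian components (`num_gaussians` in the code) models the conditional density as below. The symbols z^π, z^σ and z^μ are notation introduced here for illustration: they denote the raw outputs of the three output layers (`nn_pi`, `nn_sigma`, `nn_mu`) applied to the last hidden layer. The softmax, exponential and identity maps match those in the MDN's `forward` method, and training minimises the average negative log-likelihood (the notebook's loss additionally caps each per-sample NLL at 20).

```latex
p(y \mid \mathbf{x}) = \sum_{k=1}^{K} \pi_k(\mathbf{x})\,
    \mathcal{N}\!\left(y;\ \mu_k(\mathbf{x}),\ \sigma_k^2(\mathbf{x})\right),
\qquad \sum_{k=1}^{K} \pi_k(\mathbf{x}) = 1

\pi_k(\mathbf{x}) = \frac{\exp\left(z^{\pi}_k\right)}{\sum_{j=1}^{K} \exp\left(z^{\pi}_j\right)},
\qquad \sigma_k(\mathbf{x}) = \exp\left(z^{\sigma}_k\right),
\qquad \mu_k(\mathbf{x}) = z^{\mu}_k

\mathcal{L} = -\frac{1}{N} \sum_{n=1}^{N} \log p(y_n \mid \mathbf{x}_n)
```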

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as D
import torch.nn.functional as F

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Mixture density network module - this defines our MDN architecture
class MDN(nn.Module):
    def __init__(self, input_dims, num_hidden, num_gaussians):
        super(MDN, self).__init__()
        self.input_dims = input_dims
        self.num_gaussians = num_gaussians
        self.num_hidden = num_hidden
        # Stack of fully connected hidden layers
        linear_layers = [nn.Linear(input_dims, num_hidden[0])]
        for i in range(1, len(num_hidden)):
            linear_layers.append(nn.Linear(num_hidden[i - 1], num_hidden[i]))
        self.nn_h = nn.ModuleList(linear_layers)
        # Output heads: mixture weights, standard deviations and means
        self.nn_pi = nn.Linear(num_hidden[-1], num_gaussians)
        self.nn_sigma = nn.Linear(num_hidden[-1], num_gaussians)
        self.nn_mu = nn.Linear(num_hidden[-1], num_gaussians)

    def forward(self, x):
        for nn_h in self.nn_h:
            x = F.relu(nn_h(x))
        sigma = torch.exp(self.nn_sigma(x))   # exp keeps standard deviations positive
        mu = self.nn_mu(x)
        pi = F.softmax(self.nn_pi(x), dim=1)  # mixture weights sum to 1
        return mu, sigma, pi


# Negative log-likelihood of y under the Gaussian mixture output by our MDN
def gaussian_mixture_NLL(y, mu, sigma, pi):
    num_gaussians = mu.shape[1]
    mu = mu.view(-1, num_gaussians, 1)
    sigma = sigma.view(-1, num_gaussians, 1)
    comp = D.Independent(D.Normal(mu, sigma), 1)
    mix = D.Categorical(pi)
    gmm = D.MixtureSameFamily(mix, comp)
    nll = -gmm.log_prob(y)
    # Cap the loss so values with (near-)zero probability don't cause exponentially large errors
    nll = torch.clamp(nll, max=20)
    return torch.mean(nll)
```
```python
from torch.utils.data import TensorDataset, DataLoader

# Define MDN
input_dims = X.shape[1]
num_hidden = [24, 24, 24]
num_gaussians = 5
MDN_model = MDN(input_dims, num_hidden, num_gaussians).to(device)

# Define training parameters
lr = 1e-3
batch_size = 32
epochs_max = 1000
torch.manual_seed(seed)

##### Process data to use in MDN #####
# Reshape y to a 2D column vector, easier to work with in PyTorch
y = y.reshape(-1, 1)

# Create validation set (0.2/0.2/0.6 test/validation/train)
X_not_test, X_test, y_not_test, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)
X_train, X_vali, y_train, y_vali = train_test_split(X_not_test, y_not_test, test_size=0.25, random_state=seed)

# Convert to PyTorch tensors
X_train_ten = torch.from_numpy(X_train).to(device).float()
X_vali_ten = torch.from_numpy(X_vali).to(device).float()
X_test_ten = torch.from_numpy(X_test).to(device).float()
y_train_ten = torch.from_numpy(y_train).to(device).float()
y_vali_ten = torch.from_numpy(y_vali).to(device).float()
y_test_ten = torch.from_numpy(y_test).to(device).float()

# Normalise features and label using training-set statistics
X_mean = torch.mean(X_train_ten, dim=0)
X_std = torch.std(X_train_ten, dim=0)
X_train_ten = (X_train_ten - X_mean) / X_std
X_vali_ten = (X_vali_ten - X_mean) / X_std
X_test_ten = (X_test_ten - X_mean) / X_std
y_mean = torch.mean(y_train_ten, dim=0)
y_std = torch.std(y_train_ten, dim=0)
y_train_ten = (y_train_ten - y_mean) / y_std
y_vali_ten = (y_vali_ten - y_mean) / y_std
y_test_ten = (y_test_ten - y_mean) / y_std

# Define batches
dataset_train = TensorDataset(X_train_ten, y_train_ten)
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

# Define optimiser
optimiser = optim.Adam(MDN_model.parameters(), lr=lr)

# Variables to draw learning curves
training_losses = []
validation_losses = []

for epoch in range(epochs_max):
    epoch_training_loss = 0
    for X_batch, y_batch in loader_train:
        # Make prediction
        mu, sigma, pi = MDN_model(X_batch)
        # Compute loss
        loss = gaussian_mixture_NLL(y_batch, mu, sigma, pi)
        epoch_training_loss += loss.item() * y_batch.shape[0]
        # Backpropagation
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
    training_losses.append(epoch_training_loss / y_train.shape[0])

    # Validation (no gradients needed)
    with torch.no_grad():
        mu, sigma, pi = MDN_model(X_vali_ten)
        loss = gaussian_mixture_NLL(y_vali_ten, mu, sigma, pi)
    validation_losses.append(loss.item())

    # Early stopping: stop once the latest validation loss exceeds the mean of the last 10
    if len(validation_losses) > 10 and validation_losses[-1] > np.mean(validation_losses[-10:]):
        break

# Plot learning curves
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
plt.plot(training_losses, label="Training")
plt.plot(validation_losses, label="Validation")
plt.title("Learning curves for MDN")
plt.xlabel("Epoch")
plt.ylabel("Loss (NLL)")
plt.legend()
```
[Figure: Learning curves for MDN (training and validation NLL per epoch)]

Making a Prediction

So now we've trained a model that outputs an estimated PDF of chl-a for each input. How do we get a point estimate for each test data point, so that we can compare performance with the random forest? We take the mode of each predicted distribution as our prediction. This gives us regression-style predictions while keeping the MDN's ability to model inverse problems, since the mode is the single most likely value rather than an average across competing solutions.
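A small NumPy sketch of the difference between the mode and the mean of one predicted mixture. The mixture parameters are hypothetical; the 1001-point grid over [-3, 3] mirrors the grid used in the prediction code. The mean lands in a low-density valley between the two components, while the mode picks the more probable one.

```python
import numpy as np


def mixture_pdf(grid, mu, sigma, pi):
    """PDF of a 1-D Gaussian mixture evaluated on a grid (mu, sigma, pi are 1-D arrays)."""
    comps = np.exp(-0.5 * ((grid[:, None] - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
    return comps @ pi  # weight each component by its mixture probability


# Hypothetical bimodal MDN output for one test point (normalised units)
mu = np.array([-1.0, 1.5])
sigma = np.array([0.3, 0.3])
pi = np.array([0.4, 0.6])

grid = np.linspace(-3, 3, 1001)
pdf = mixture_pdf(grid, mu, sigma, pi)

mode = grid[np.argmax(pdf)]  # ~1.5, the peak of the heavier component
mean = np.sum(pi * mu)       # 0.5, between the two peaks
print(f"mode = {mode:.2f}, mean = {mean:.2f}")
```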

```python
# Make prediction for test set (no gradients needed)
with torch.no_grad():
    mu, sigma, pi = MDN_model(X_test_ten)


# Returns a 2D array [num_distributions x num_bins] of Gaussian PDF values
def gaussian_pdf(mu, sigma, bins):
    sqrt_2_pi = 2.5066283
    return 1 / (sigma * sqrt_2_pi) * torch.exp(-torch.pow(bins - mu, 2) / (2 * sigma * sigma))


# Returns a [num_mixtures x num_bins] array of Gaussian mixture PDFs
def gaussian_mixture_pdf(mu, sigma, pi, bins):
    n_mixtures = mu.shape[0]
    n_bins = len(bins)
    gmm_pdf = torch.zeros([n_mixtures, n_bins], device=mu.device)
    for k in range(mu.shape[1]):
        gmm_pdf = gmm_pdf + pi[:, k].view(-1, 1) * gaussian_pdf(mu[:, k].view(-1, 1), sigma[:, k].view(-1, 1), bins)
    return gmm_pdf


# Calculate PDF vector for each prediction on a grid over the normalised target
bin_min, bin_max = -3, 3
bins = torch.linspace(bin_min, bin_max, 1001, device=mu.device)
pdfs = gaussian_mixture_pdf(mu, sigma, pi, bins)

# Find mode of each estimate
pdf_modes_inds = torch.argmax(pdfs, dim=1)
pdf_modes = bins[pdf_modes_inds]

# Un-normalise predictions
y_pred = pdf_modes * y_std + y_mean
y_pred = y_pred.detach().cpu().numpy()

# Measure prediction metrics
MDN_r2 = r2_score(y_test, y_pred)
MDN_mse = mean_squared_error(y_test, y_pred)
print("MSE: {:.4f}, R2: {:.4f}".format(MDN_mse, MDN_r2))

# Plot the estimated PDFs and corresponding modes
pdfs_np = pdfs.detach().cpu().numpy()
fig, axs = plt.subplots(1, 2, figsize=(15, 8), gridspec_kw={'width_ratios': [2, 5]})
bin_min_unnorm = bin_min * y_std.item() + y_mean.item()
bin_max_unnorm = bin_max * y_std.item() + y_mean.item()
axs[0].imshow(np.repeat(np.sqrt(pdfs_np), 15, axis=0), extent=[bin_min_unnorm, bin_max_unnorm, 0, 15])
axs[0].scatter(y_test, np.linspace(15, 0, len(y_test)), c="m", marker="o", label="True")
axs[0].scatter(y_pred, np.linspace(15, 0, len(y_pred)), c="r", marker="x", label="Pred")
axs[0].set_xlabel("ln(Chl-a [g/m^2])")
axs[0].set_ylabel("Data point")
axs[0].set_title("MDN Predictions")
axs[0].legend()

# Plot prediction vs. real chlorophyll-a values
x_y_line = [np.min(y_test), np.max(y_test)]
axs[1].plot(x_y_line, x_y_line, "k--")
axs[1].scatter(y_test, y_pred)
axs[1].set_xlabel("True ln(Chl-a [g/m^2])")
axs[1].set_ylabel("Predicted ln(Chl-a [g/m^2])")
axs[1].set_title("MDN Predictions")
```

MSE: 0.3349, R2: 0.7640

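[Figure: MDN Predictions (left: estimated PDFs with true and predicted values for each test point; right: predicted vs. true ln(Chl-a))]
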
Step 3 - Next steps

The Results

Looking at MSE and R2, this model seems to perform similarly to the random forest, but we've only tried one random seed, so this result is inconclusive. Repeated sampling would allow us to compare the performance of these models more reliably, but that's out of the scope of this notebook.
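A minimal sketch of what that repeated sampling could look like for the random forest, using scikit-learn's `ShuffleSplit` to draw ten random 80/20 splits and reporting the mean and spread of R2. Synthetic data from `make_regression` (same shape as the notebook's 346 × 17 feature matrix) stands in for `X` and `y` so the snippet runs on its own. For a fair comparison, the MDN would have to be retrained and scored inside a loop over the same splits, which is not shown here.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import ShuffleSplit, cross_val_score

# Placeholder data with the notebook's shape; replace with X, y from data_all
X_demo, y_demo = make_regression(n_samples=346, n_features=17, noise=10.0, random_state=0)

# Ten random 80/20 train/test splits, each scored with R^2 on its held-out 20%
splitter = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
scores = cross_val_score(
    RandomForestRegressor(n_estimators=100, random_state=0),
    X_demo, y_demo, cv=splitter, scoring="r2",
)
print(f"R2 = {scores.mean():.3f} +/- {scores.std():.3f} over {len(scores)} splits")
```

Reporting a spread across splits, rather than a single number, is what would make a small gap like 0.74 vs. 0.76 interpretable.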

We could apply either of these models to each water pixel in a Sentinel-3 image to estimate the concentration of chlorophyll-a for each pixel. This would allow us to create heatmaps of chlorophyll-a concentration like the one below (Pahlevan et al., 2020), but this is also out of scope for this notebook.

[Figure: Polymer Heatmaps (Pahlevan et al., 2020)]
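Applying a trained model pixel by pixel could look roughly like the sketch below. It rests on assumptions not covered in this notebook: `predict_map` is a hypothetical helper, the image is assumed to be already atmospherically corrected and arranged as a `(height, width, n_features)` array whose features are preprocessed exactly like the training columns, and `water_mask` is assumed to come from the same cloud/land filtering used to build the dataset. For the MDN, `model.predict` would be replaced by the mode computation used earlier.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor


def predict_map(model, image, water_mask):
    """Hypothetical helper: apply a per-pixel regressor to a (height, width, n_features) image.

    Pixels outside water_mask (land, cloud, etc.) are returned as NaN.
    """
    height, width, _ = image.shape
    chl_map = np.full((height, width), np.nan)
    pixels = image[water_mask]  # shape (n_water_pixels, n_features)
    if pixels.shape[0] > 0:
        chl_map[water_mask] = model.predict(pixels)
    return chl_map


# Toy usage with random data standing in for a real scene and a real trained model
rng = np.random.default_rng(0)
model = RandomForestRegressor(n_estimators=10, random_state=0).fit(rng.random((50, 17)), rng.random(50))
image = rng.random((4, 5, 17))
water_mask = np.zeros((4, 5), dtype=bool)
water_mask[1:3, 1:4] = True
chl_map = predict_map(model, image, water_mask)
print(chl_map)
```

The resulting `chl_map` (in ln(Chl-a) units) could then be displayed with `plt.imshow` to produce a heatmap.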

Improving the MDN Model Further

An issue with MDNs is that they tend to overfit, especially with few data points, and we only have ~300. This may be why the MDN model doesn't substantially outperform the RF in this notebook. I tackle this problem in my paper "Semi-Supervised Conditional Density Estimation with Wasserstein Mixture Density Networks", published at AAAI 2022, where I use unlabelled data to regularise the learning process of MDNs and improve performance. In the paper, I apply this technique to remote sensing Chl-a datasets and show the utility of the method in this setting, but that's also out of the scope of this notebook.

I hope this notebook was useful and interesting. Feel free to message me at if you have any questions regarding this work.