In this notebook, we explore state-of-the-art machine learning approaches to the task of estimating the concentration of chlorophyll-a (chl-a) from remote sensing data (satellites). The goal of this notebook is to demonstrate the possibilities of publicly available datasets in solving certain real-world environmental problems, and the Machine Learning opportunities that these problems invite.
Harmful Algal Blooms
Algal blooms are a rapid build-up of algae in an aquatic system. They can be harmful to local ecosystems, aquaculture, and even human health. As a result, councils monitor the concentration of algae in important lakes around the country. This is done by manually collecting water samples, often by boat, and analysing them in a lab - a process that is expensive and slow, resulting in poor spatial and temporal resolution.
Satellite Data
To mitigate these issues, environmental scientists have begun using satellite data to estimate algae concentrations. Rather than estimating the concentration of algae directly, the concentration of chlorophyll-a - the green pigment in algae - is estimated as a proxy, since it has a more consistent optical signal.
In this work, we use the Ocean and Land Colour Instrument (OLCI) sensor of the Sentinel-3 satellite. This satellite has coarse spatial resolution (~300m pixels) but high spectral resolution. To understand this, note that satellite images are multispectral - they don't just measure the intensity of red/green/blue light, but measure light at a range of defined frequencies called bands. OLCI has 21 bands, specifically placed to effectively estimate aquatic parameters such as Chl: https://sentinels.copernicus.eu/web/sentinel/user-guides/sentinel-3-olci/resolutions/radiometric
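For reference, here are the approximate centre wavelengths of a few OLCI bands, rounded from the Copernicus documentation linked above - treat these values as indicative only:

# Approximate centre wavelengths (nm) of selected OLCI bands - indicative values,
# rounded from the Copernicus documentation linked above
olci_band_centres_nm = {
    "Oa01": 400.0,
    "Oa04": 490.0,
    "Oa06": 560.0,
    "Oa08": 665.0,
    "Oa11": 708.75,
    "Oa17": 865.0,
}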
The Task
This multispectral data will be the input (features) to our model; the output (target) will be the ground-truth (called in-situ) concentration of chl-a measured by environmental scientists. More specifically, we will predict the chl-a concentration from the multispectral data of the satellite pixel corresponding to where the in-situ sample was collected. We don't consider adjacent pixels, because these have no labels and can have vastly different concentrations of algae.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn
The satellite data
The Sentinel-3 OLCI data can be accessed publicly at https://scihub.copernicus.eu/
A TAIAO page explaining this dataset can be visited at https://taiao.ai/datasets/sentinel-3-remote-sensing-data.en/
The in-situ data
The in-situ data is available to the public and can be obtained by corresponding with the various regional councils - however, the process of obtaining this data tends to differ between councils.
The data processing
To avoid a lengthy data manipulation process, in this notebook we will use data that has already been processed. The following steps have been taken (a schematic sketch of the pixel-matchup step is given below):

- Each in-situ measurement was matched to the Sentinel-3 OLCI pixel covering its sampling location.
- The spectra were Rayleigh-corrected (the "Rayleigh_rBRR" bands).
- Each spectrum was normalised, with the pre-normalisation sum kept as a separate feature.
- The in-situ Chl-a values were log-transformed.
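As a rough illustration, the matchup step might look something like this (all names here are hypothetical - the processed dataset used below was produced upstream of this notebook):

# Hypothetical sketch of the in-situ/pixel matchup (illustrative names only)
def match_pixel(scene, lats, lons, sample_lat, sample_lon):
    # scene: (H, W, n_bands) reflectance array; lats/lons: (H, W) coordinate grids
    dist2 = (lats - sample_lat) ** 2 + (lons - sample_lon) ** 2
    i, j = np.unravel_index(np.argmin(dist2), dist2.shape)
    return scene[i, j, :]  # spectrum of the pixel nearest the sampling location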
This is the resulting dataset:
data_all = pd.read_csv(r"data\NZ_chl\data.csv")
display(data_all)
| Rayleigh_rBRR_01 | Rayleigh_rBRR_02 | Rayleigh_rBRR_03 | Rayleigh_rBRR_04 | Rayleigh_rBRR_05 | Rayleigh_rBRR_06 | Rayleigh_rBRR_07 | Rayleigh_rBRR_08 | Rayleigh_rBRR_09 | Rayleigh_rBRR_10 | Rayleigh_rBRR_11 | Rayleigh_rBRR_12 | Rayleigh_rBRR_16 | Rayleigh_rBRR_17 | Rayleigh_rBRR_18 | Rayleigh_rBRR_21 | sum | Chl-a
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.083240 | 0.087611 | 0.076884 | 0.072193 | 0.073173 | 0.071898 | 0.049810 | 0.041827 | 0.039545 | 0.039801 | 0.049114 | 0.066925 | 0.061577 | 0.049023 | 0.045723 | 0.091657 | 0.395433 | 1.098612 |
1 | 0.103721 | 0.118381 | 0.114508 | 0.125354 | 0.121288 | 0.101748 | 0.049893 | 0.038541 | 0.038068 | 0.035887 | 0.027366 | 0.032759 | 0.027330 | 0.008797 | 0.003411 | 0.052948 | 0.285036 | 0.587787 |
2 | 0.132218 | 0.148740 | 0.150831 | 0.179540 | 0.173241 | 0.148857 | 0.054293 | 0.036442 | 0.037513 | 0.034342 | -0.066170 | 0.011522 | 0.004069 | -0.021140 | -0.030790 | 0.006492 | 0.257608 | 0.587787 |
3 | 0.085762 | 0.096356 | 0.083622 | 0.081695 | 0.086986 | 0.111562 | 0.067918 | 0.052424 | 0.048378 | 0.048182 | 0.041094 | 0.052027 | 0.046095 | 0.023720 | 0.013796 | 0.060382 | 0.292001 | 1.945910 |
4 | 0.049680 | 0.057744 | 0.055835 | 0.066036 | 0.073412 | 0.115075 | 0.097171 | 0.072158 | 0.065126 | 0.065670 | 0.072923 | 0.049521 | 0.047245 | 0.031630 | 0.025981 | 0.054793 | 0.566259 | 2.772589 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
341 | 0.106724 | 0.117034 | 0.097616 | 0.083563 | 0.085665 | 0.088877 | 0.055979 | 0.046036 | 0.045310 | 0.042571 | 0.033485 | 0.053389 | 0.044693 | 0.022084 | 0.009323 | 0.067651 | 0.267427 | 1.568616 |
342 | 0.096544 | 0.103171 | 0.087800 | 0.076552 | 0.073745 | 0.072551 | 0.047403 | 0.040020 | 0.039099 | 0.038091 | 0.040281 | 0.066260 | 0.061441 | 0.049710 | 0.043277 | 0.064055 | 0.528983 | 0.916291 |
343 | 0.116344 | 0.117222 | 0.097944 | 0.095083 | 0.092039 | 0.095806 | 0.061073 | 0.047548 | 0.046872 | 0.047027 | 0.035219 | 0.042143 | 0.034332 | 0.016040 | 0.007314 | 0.047994 | 0.356656 | 2.302585 |
344 | 0.109086 | 0.122802 | 0.107111 | 0.105126 | 0.099380 | 0.087211 | 0.048677 | 0.040532 | 0.039715 | 0.038182 | 0.033794 | 0.041988 | 0.037847 | 0.025256 | 0.020983 | 0.042310 | 0.308401 | 0.587787 |
345 | 0.051279 | 0.053541 | 0.055215 | 0.060767 | 0.063464 | 0.077048 | 0.061565 | 0.056779 | 0.055658 | 0.055541 | 0.061656 | 0.069590 | 0.069290 | 0.067430 | 0.065360 | 0.075817 | 0.928310 | 1.131402 |
346 rows × 18 columns
Each row corresponds to a ground-truth data point. The multispectral satellite data for the pixel matched to that point is in the "Rayleigh_rBRR_*" columns, and the ground-truth chl-a is in the "Chl-a" column. Note that the multispectral data has been normalised (as is convention in remote sensing), but the pre-normalisation spectral sum is recorded in the "sum" column to avoid loss of information. We also note that the Chl-a values have been log-transformed, since chl-a concentrations naturally follow an approximately exponential distribution.
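To make that preprocessing concrete, here is one plausible version of the normalisation and log transform described above (dividing each spectrum by its band sum is an assumption consistent with the "sum" column; the inputs are random stand-ins, as the real pipeline ran upstream of this notebook):

# Illustrative preprocessing on hypothetical raw data
raw_spectra = np.abs(np.random.rand(5, 16))     # stand-in for Rayleigh-corrected band values
chl_raw = np.array([3.0, 1.8, 1.8, 7.0, 16.0])  # stand-in for in-situ Chl-a measurements
band_sum = raw_spectra.sum(axis=1, keepdims=True)
spectra_norm = raw_spectra / band_sum           # normalised spectra ("Rayleigh_rBRR_*" columns)
chl_log = np.log(chl_raw)                       # log-transformed target ("Chl-a" column)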
Let's try to build a simple random forest (RF) model to estimate Chl-a. Note that in this notebook we build just one model per approach, avoiding repeated sampling or cross-validation for simplicity.
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error

seed = 0
np.random.seed(seed)

X = data_all.drop("Chl-a", axis=1).values
y = data_all["Chl-a"].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

print(X.shape, y.shape)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
(346, 17) (346,) (276, 17) (276,) (70, 17) (70,)
# Create RF model object
RF_model = RandomForestRegressor()
# Train model
RF_model.fit(X_train, y_train)
# Make prediction on test data
y_pred = RF_model.predict(X_test)

# Evaluate performance of model
RF_r2 = r2_score(y_test, y_pred)
RF_mse = mean_squared_error(y_test, y_pred)
print("Random Forest Performance - MSE: {:.4f}, R2: {:.4f}".format(RF_mse, RF_r2))

# Plot prediction vs. real chlorophyll-a values
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
x_y_line = [np.min(y_test), np.max(y_test)]
plt.plot(x_y_line, x_y_line, "k--")
plt.scatter(y_test, y_pred)
plt.xlabel("True ln(Chl-a [g/m^2])")
plt.ylabel("Predicted ln(Chl-a [g/m^2])")
plt.title("Random Forest Predictions")
plt.show()
Random Forest Performance - MSE: 0.3623, R2: 0.7447
We can see that this simple model learns the relationship between satellite data and in-situ algae data reasonably accurately. This is very exciting, as it indicates that the satellite data we've matched to the in-situ data is informative about algae concentration - so now let's see how we can develop a better ML model for this task. This random forest model will act as a baseline for evaluating future models.
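One quick, optional diagnostic for this baseline is to look at which bands the forest leans on (impurity-based importances; interpret them loosely, as correlated bands share credit):

# Top-5 impurity-based feature importances of the RF baseline
feature_names = data_all.drop("Chl-a", axis=1).columns
importances = sorted(zip(feature_names, RF_model.feature_importances_), key=lambda t: -t[1])
for name, imp in importances[:5]:
    print("{}: {:.3f}".format(name, imp))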
Conditional Density Estimation
Recently, researchers have advocated the use of Conditional Density Estimation (CDE) for this task of estimating Chl-a. In CDE, instead of designing a model that predicts a single continuous value (as we do in regression tasks), we design a model that outputs a probability density estimate over the target variable. CDE has the advantages of modelling uncertainty and of handling tasks where multiple labels may be correct for a single input.
CDE makes sense for this task because we're modelling an inverse problem - predicting the underlying physical variables from their observed signals. Multiple combinations of substances in the water can produce the same optical signal, so for a given multispectral observation there may be several plausible Chl-a concentrations that could have produced it. The case for CDE is made in "Seamless retrievals of chlorophyll-a from Sentinel-2 (MSI) and Sentinel-3 (OLCI) in inland and coastal waters: A machine-learning approach" by Pahlevan et al., 2020.
If this isn't intuitive, imagine a bucket of water. If we add some drops of green paint to the bucket, it's pretty easy to estimate the concentration of green paint in the bucket. However, if we then add some yellow, red, maybe even turquoise paint to the bucket, making a brown mix, it becomes hard (or even impossible) to accurately measure how much green paint is in the bucket.
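To visualise the kind of output a CDE model produces, here is a toy bimodal density: two different ln(Chl-a) values that are equally consistent with a single observed spectrum (the numbers are illustrative only):

from scipy.stats import norm
# Toy conditional density: two plausible ln(Chl-a) values for one observed spectrum
y_grid = np.linspace(-3, 5, 400)
density = 0.5 * norm.pdf(y_grid, 0.0, 0.5) + 0.5 * norm.pdf(y_grid, 2.5, 0.7)
plt.plot(y_grid, density)
plt.xlabel("ln(Chl-a)")
plt.ylabel("p(ln(Chl-a) | spectrum)")
plt.title("Toy bimodal conditional density (illustrative)")
plt.show()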
Mixture Density Networks
So, CDE is a sensible approach to modelling Chl-a. The most common approach to CDE is the Mixture Density Network (MDN), a type of neural network whose output layer parameterises a mixture distribution. This blog explains MDNs in great depth and clarity: https://towardsdatascience.com/a-hitchhikers-guide-to-mixture-density-networks-76b435826cca
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as D
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Use CUDA tensors by default only when a GPU is actually available
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Mixture density network module - this defines our MDN architecture
class MDN(nn.Module):
def __init__(self, input_dims, num_hidden, num_gaussians):
super(MDN, self).__init__()
self.input_dims = input_dims
self.num_gaussians = num_gaussians
self.num_hidden = num_hidden
linear_layers = [nn.Linear(input_dims, num_hidden[0])]
if len(num_hidden) > 1:
for i in range(1, len(num_hidden)):
linear_layers.append(nn.Linear(num_hidden[i-1], num_hidden[i]))
self.nn_h = nn.ModuleList(linear_layers)
self.nn_pi = nn.Linear(num_hidden[-1], num_gaussians)
self.nn_sigma = nn.Linear(num_hidden[-1], num_gaussians)
self.nn_mu = nn.Linear(num_hidden[-1], num_gaussians)
def forward(self, x):
for nn_h in self.nn_h:
x = F.relu(nn_h(x))
sigma = torch.exp(self.nn_sigma(x))
mu = self.nn_mu(x)
pi = F.softmax(self.nn_pi(x), dim=1)
return mu, sigma, pi
# Computes the negative log-likelihood (NLL) of the Gaussian mixture output by our MDN
def gaussian_mixture_NLL(y, mu, sigma, pi):
    mu = mu.unsqueeze(-1)      # [batch, num_components, 1]
    sigma = sigma.unsqueeze(-1)
comp = D.Independent(D.Normal(mu, sigma), 1)
mix = D.Categorical(pi)
gmm = D.MixtureSameFamily(mix, comp)
nll = -gmm.log_prob(y)
    nll = torch.clamp(nll, max=20)  # Prevents zero-probability points causing exponentially large errors
nll_loss = torch.mean(nll)
return nll_loss
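As a quick, illustrative sanity check (not part of the original pipeline), the loss should return a finite scalar for random inputs of the right shapes:

# Illustrative smoke test of the loss: 4 points, 5 mixture components
_mu = torch.randn(4, 5)
_sigma = torch.ones(4, 5)
_pi = torch.full((4, 5), 0.2)  # uniform mixture weights
_y = torch.randn(4, 1)
print(gaussian_mixture_NLL(_y, _mu, _sigma, _pi))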
# Define MDN
input_dims = X.shape[1]
num_hidden = [24, 24, 24]
num_gaussians = 5
MDN_model = MDN(input_dims, num_hidden, num_gaussians)

# Define training parameters
lr = 1e-3
batch_size = 32
epochs_max = 1000
torch.manual_seed(seed)

##### Process data to use in MDN #####
# Reshape y to a 2D column vector - easier to work with in PyTorch
y = y.reshape(-1, 1)

# Create validation set (0.2/0.2/0.6 test/validation/train)
X_not_test, X_test, y_not_test, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)
X_train, X_vali, y_train, y_vali = train_test_split(X_not_test, y_not_test, test_size=0.25, random_state=seed)

# Convert to pytorch tensors
X_train_ten = torch.from_numpy(X_train).to(device).float()
X_vali_ten = torch.from_numpy(X_vali).to(device).float()
X_test_ten = torch.from_numpy(X_test).to(device).float()
y_train_ten = torch.from_numpy(y_train).to(device).float()
y_vali_ten = torch.from_numpy(y_vali).to(device).float()
y_test_ten = torch.from_numpy(y_test).to(device).float()

# Normalise features and label (statistics computed on the training set only)
X_mean = torch.mean(X_train_ten, axis=0)
X_std = torch.std(X_train_ten, axis=0)
X_train_ten = (X_train_ten - X_mean) / X_std
X_vali_ten = (X_vali_ten - X_mean) / X_std
X_test_ten = (X_test_ten - X_mean) / X_std
y_mean = torch.mean(y_train_ten, axis=0)
y_std = torch.std(y_train_ten, axis=0)
y_train_ten = (y_train_ten - y_mean) / y_std
y_vali_ten = (y_vali_ten - y_mean) / y_std
y_test_ten = (y_test_ten - y_mean) / y_std

# Define batches
dataset_train = torch.utils.data.TensorDataset(X_train_ten, y_train_ten)
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

# Define optimiser
optimiser = optim.Adam(MDN_model.parameters(), lr=lr)

# Variables to draw learning curves
training_losses = []
validation_losses = []

for epoch in range(epochs_max):
    epoch_training_loss = 0
    for X_batch, y_batch in loader_train:
        # Make prediction
        mu, sigma, pi = MDN_model(X_batch)
        # Compute loss
        loss = gaussian_mixture_NLL(y_batch, mu, sigma, pi)
        epoch_training_loss += loss.item() * y_batch.shape[0]
        # Backpropagation
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
    training_losses.append(epoch_training_loss / y_train.shape[0])

    # Validation - stop early once the validation loss stops improving
    mu, sigma, pi = MDN_model(X_vali_ten)
    loss = gaussian_mixture_NLL(y_vali_ten, mu, sigma, pi)
    validation_losses.append(loss.item())
    if len(validation_losses) > 10 and validation_losses[-1] > np.mean(validation_losses[-10:]):
        break

# Plot learning curves
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
plt.plot(training_losses, label="Training")
plt.plot(validation_losses, label="Validation")
plt.title("Learning curves for MDN")
plt.xlabel("Epoch")
plt.ylabel("Loss (NLL)")
plt.legend()
plt.show()
Making a Prediction
So now we've trained a model that outputs a PDF estimate for each input. How do we get a point estimate for each test data point, so we can compare performance with the random forest? We take the mode of each predicted distribution as our prediction - this lets us generate regression-style predictions while keeping the MDN's ability to model inverse problems.
# Make prediction for test set
mu, sigma, pi = MDN_model(X_test_ten)
# Returns a [num_points x num_bins] array of Gaussian-mixture PDF values
def gaussian_mixture_pdf(mu, sigma, pi, bins):
    n_points = mu.shape[0]
    n_bins = len(bins)
    gmm_pdf = torch.zeros([n_points, n_bins])
    for k in range(pi.shape[1]):  # loop over mixture components
        gmm_pdf = gmm_pdf + pi[:, k].view(-1, 1) * gaussian_pdf(mu[:, k].view(-1, 1), sigma[:, k].view(-1, 1), bins)
    return gmm_pdf
# Returns 2D Gaussian [num_distributions x num_bins]
def gaussian_pdf(mu, sigma, bins):
sqrt_2_pi = 2.5066283
val = 1 / (sigma * sqrt_2_pi) * torch.exp(-torch.pow(bins - mu, 2) / (2 * sigma * sigma))
return val
# Calculate PDF vector for each prediction
bin_min, bin_max = -3, 3
bins = torch.linspace(bin_min, bin_max, 1001)
pdfs = gaussian_mixture_pdf(mu, sigma, pi, bins)
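# Sanity check (illustrative addition): each estimated PDF should integrate
# to roughly 1 over a sufficiently wide bin range
# print(torch.trapezoid(pdfs, bins, dim=1))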
# Find mode of each estimate
pdf_modes_inds = torch.argmax(pdfs, axis=1)
pdf_modes = bins[pdf_modes_inds]
# Un-normalise predictions
y_pred = pdf_modes * y_std + y_mean
y_pred = y_pred.detach().cpu().numpy()
# Measure prediction metrics
MDN_r2 = r2_score(y_test, y_pred)
MDN_mse = mean_squared_error(y_test, y_pred)
print("MSE: {:.4f}, R2: {:.4f}".format(MDN_mse, MDN_r2))
# Plot the estimated PDFs and corresponding modes
pdfs_np = pdfs.detach().cpu().numpy()
fig, axs = plt.subplots(1,2,figsize=(15,8),gridspec_kw={'width_ratios': [2, 5]})
bin_min_unnorm = bin_min * y_std.item() + y_mean.item()
bin_max_unnorm = bin_max * y_std.item() + y_mean.item()
axs[0].imshow(np.repeat(np.sqrt(pdfs_np), 15, axis=0), extent=[bin_min_unnorm,bin_max_unnorm,0,15])
axs[0].scatter(y_test, np.linspace(15,0,len(y_test)), c="m", marker="o", label="True")
axs[0].scatter(y_pred, np.linspace(15,0,len(y_pred)), c="r", marker="x", label="Pred")
axs[0].set_xlabel("True ln(Chl-a [g/m^2])")
axs[0].set_ylabel("Data point")
axs[0].set_title("MDN Predictions")
axs[0].legend()
# Plot prediction vs. real chlorophyll-a values
x_y_line = [np.min(y_test), np.max(y_test)]
axs[1].plot(x_y_line, x_y_line, "k--")
axs[1].scatter(y_test, y_pred)
axs[1].set_xlabel("True ln(Chl-a [g/m^2])")
axs[1].set_ylabel("Predicted ln(Chl-a [g/m^2])")
axs[1].set_title("MDN Predictions")
plt.show()
MSE: 0.3349, R2: 0.7640
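As an aside, the mode is not the only reasonable point estimate: the mixture mean, sum_k pi_k * mu_k, is another common choice. A minimal sketch, reusing the mu, sigma, pi predicted above (not part of the original evaluation):

# Alternative point estimate (illustrative): the mixture mean E[y|x] = sum_k pi_k * mu_k
y_pred_mean = torch.sum(pi * mu, dim=1, keepdim=True)
y_pred_mean = (y_pred_mean * y_std + y_mean).detach().cpu().numpy()
print("Mixture-mean MSE: {:.4f}, R2: {:.4f}".format(
    mean_squared_error(y_test, y_pred_mean), r2_score(y_test, y_pred_mean)))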
The Results
Looking at MSE and R2, this model performs similarly to the random forest - but we've only tried one seed, so this result is inconclusive. Repeated sampling would allow us to compare the performance of these models rigorously, but that's beyond the scope of this notebook.
We could run either of these models on every water pixel of a Sentinel-3 image to estimate the concentration of chlorophyll-a for each pixel. This would allow us to create heatmaps of chlorophyll-a concentration like the one below (Pahlevan et al., 2020) - mapping whole scenes is also beyond the scope of this notebook, but a rough sketch follows.
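For illustration only, here is what that per-pixel inference might look like - the image array is a random stand-in, not real Sentinel-3 data:

# Hypothetical sketch: predict Chl-a for every pixel of a (H, W, n_features) array
H, W = 100, 100
n_feat = X.shape[1]
image = np.random.rand(H, W, n_feat)  # stand-in for a preprocessed OLCI scene
chl_map = RF_model.predict(image.reshape(-1, n_feat)).reshape(H, W)
plt.imshow(chl_map)
plt.colorbar(label="Predicted ln(Chl-a)")
plt.title("Chl-a heatmap (illustrative, random inputs)")
plt.show()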
Improving the MDN Model Further
An issue with MDNs is that they tend to overfit, especially with few data points - and we only have ~300. This may be why the MDN doesn't substantially outperform the RF in this notebook. I tackle this problem in my paper "Semi-Supervised Conditional Density Estimation with Wasserstein Mixture Density Networks", published at AAAI 2022, where I use unlabelled data to regularise the learning process of MDNs and improve performance. In the paper, I apply this technique to remote-sensing Chl-a datasets and show its utility in this framework - but that's also beyond the scope of this notebook.
I hope this notebook was useful and interesting. Feel free to message me at ogra439@aucklanduni.ac.nz if you have any questions regarding this work.