Land Use Clustering

datashader, holoviews, bokeh, dask, dask-ml
Published: November 26, 2018 · Modified: November 2, 2023


Spectral Clustering Example

The image loaded here is a cropped portion of a Landsat 5 image of Walker Lake, Nevada.

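In addition to dask-ml, we'll use intake to read the data, cartopy to convert the lake's longitude/latitude into the image's UTM coordinates, and HoloViews (with the Bokeh backend and Datashader's regrid operation) to plot the figures. I'm just working on my laptop, so either the threaded or the distributed scheduler would do, but here I'll use the distributed scheduler for its diagnostics dashboard.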

import holoviews as hv
from holoviews import opts
from holoviews.operation.datashader import regrid
import cartopy.crs as ccrs
import dask.array as da
import dask_ml
from dask_ml.cluster import SpectralClustering
from dask.distributed import Client
hv.extension('bokeh')
dask_ml.__version__
'2023.3.24'
# In-process local cluster (workers run as threads in this process); the dashboard is still available
client = Client(processes=False)
# Alternative: a multi-process local cluster
# client = Client(n_workers=8, threads_per_worker=1)
client

Client-4085e221-7907-11ee-8aed-6045bd7a96ff
Connection method: Cluster object
Cluster type: distributed.LocalCluster
Dashboard: http://10.1.0.25:8787/status

import intake
cat = intake.open_catalog('./catalog.yml')
list(cat)
['landsat_5']
landsat_5_img = cat.landsat_5.read_chunked()
landsat_5_img
2023-11-01 22:37:35,278 - intake - WARNING - cache.py:_download:L264 - Cache progress bar in a notebook requires ipywidgets to be installed: conda/pip install ipywidgets
<xarray.DataArray (band: 6, y: 7241, x: 7961)>
dask.array<concatenate, shape=(6, 7241, 7961), dtype=int16, chunksize=(1, 256, 256), chunktype=numpy.ndarray>
Coordinates:
  * x            (x) float64 2.424e+05 2.424e+05 ... 4.812e+05 4.812e+05
  * y            (y) float64 4.414e+06 4.414e+06 ... 4.197e+06 4.197e+06
    spatial_ref  int64 0
  * band         (band) int64 1 2 3 4 5 7
Attributes:
    AREA_OR_POINT:  Area
    Band_1:         band 1 surface reflectance
    _FillValue:     -9999
    scale_factor:   1.0
    add_offset:     0.0
    long_name:      band 1 surface reflectance
crs = ccrs.epsg(32611)
x_center, y_center = crs.transform_point(-118.7081, 38.6942, ccrs.PlateCarree())
buffer = 1.7e4

xmin = x_center - buffer
xmax = x_center + buffer
ymin = y_center - buffer
ymax = y_center + buffer

ROI = landsat_5_img.sel(x=slice(xmin, xmax), y=slice(ymax, ymin))
ROI = ROI.where(ROI > ROI.attrs['_FillValue'])
bands = ROI.astype(float)
bands = (bands - bands.mean()) / bands.std()
bands
<xarray.DataArray (band: 6, y: 1134, x: 1133)>
dask.array<truediv, shape=(6, 1134, 1133), dtype=float64, chunksize=(1, 256, 256), chunktype=numpy.ndarray>
Coordinates:
  * x            (x) float64 3.345e+05 3.345e+05 ... 3.684e+05 3.684e+05
  * y            (y) float64 4.301e+06 4.301e+06 ... 4.267e+06 4.267e+06
    spatial_ref  int64 0
  * band         (band) int64 1 2 3 4 5 7
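The ROI cell first replaces fill-value pixels with NaN and then standardizes all bands with a single global mean and standard deviation (not one per band). A minimal NumPy sketch of those two steps on a toy array, assuming float data so that the NaN-skipping behaviour of xarray's `.mean()` and `.std()` (population std, `ddof=0`) matches `np.nanmean` and `np.nanstd`; the toy values are made up for illustration:

```python
import numpy as np

FILL_VALUE = -9999   # matches the `_FillValue` attribute of the Landsat array

# Toy int16 "image" shaped (band, y, x) with one fill-value pixel.
raw = np.array([[[120, 340], [-9999, 560]],
                [[ 80, 410], [  300, 220]]], dtype=np.int16)

# Equivalent of ROI.where(ROI > FILL_VALUE): keep valid values, NaN elsewhere.
masked = np.where(raw > FILL_VALUE, raw.astype(float), np.nan)

# One global mean/std over every band and pixel, skipping NaNs.
z = (masked - np.nanmean(masked)) / np.nanstd(masked)
print(np.round(z, 2))
```

Scaling every band by the same statistics keeps the relative brightness between bands; per-band standardization would be a different (equally plausible) modelling choice.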
opts.defaults(
    opts.Image(invert_yaxis=True, width=250, height=250, tools=['hover'], cmap='viridis'))
hv.Layout([regrid(hv.Image(band, kdims=['x', 'y'])) for band in bands[:3]])
flat_input = bands.stack(z=('y', 'x'))
flat_input
<xarray.DataArray (band: 6, z: 1284822)>
dask.array<reshape, shape=(6, 1284822), dtype=float64, chunksize=(1, 52118), chunktype=numpy.ndarray>
Coordinates:
    spatial_ref  int64 0
  * band         (band) int64 1 2 3 4 5 7
  * z            (z) object MultiIndex
  * y            (z) float64 4.301e+06 4.301e+06 ... 4.267e+06 4.267e+06
  * x            (z) float64 3.345e+05 3.345e+05 ... 3.684e+05 3.684e+05
flat_input.shape
(6, 1284822)
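The 1,284,822 samples are simply the pixel count of the ROI. Assuming the 30 m pixel size of the Landsat 5 TM reflective bands, a ±17 km window (`buffer = 1.7e4`) spans about 34,000 / 30 ≈ 1,133 pixels per side. A short arithmetic check:

```python
# Hedged arithmetic check; assumes the 30 m pixel size of Landsat 5 TM
# reflective bands (1, 2, 3, 4, 5 and 7).
PIXEL_SIZE_M = 30.0
buffer_m = 1.7e4                  # same value as `buffer` in the ROI cell

window_m = 2 * buffer_m           # width of xmin..xmax (and ymin..ymax)
pixels_per_side = window_m / PIXEL_SIZE_M
print(f"{window_m:.0f} m window -> about {pixels_per_side:.1f} pixels per side")

# Inclusive label slicing picks up 1133 or 1134 pixel centers per side,
# depending on how the window lines up with the pixel grid.
ny, nx = 1134, 1133               # shape of bands[0] reported above
n_samples = ny * nx
print(n_samples)                  # 1284822, matching flat_input.shape[1]
```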

We'll reshape the image into the layout dask-ml and scikit-learn expect, (n_samples, n_features), where each pixel is a sample and n_features is 6, one per band. Then we'll persist that array in memory. The dataset is still small at this point: 1,284,822 x 6 float64 values come to about 62 MB. The large object, which dask helps us manage, is the intermediate n_samples x n_samples affinity matrix that spectral clustering operates on. For our 1,134 x 1,133 pixel subset, a dense version of that matrix would hold about 1.65 x 10^12 entries, or roughly 13 TB as float64.
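A back-of-the-envelope check of those sizes, assuming float64 (8 bytes per value) as in `X` below:

```python
# Memory estimate for this example; assumes 8-byte float64 values.
BYTES_PER_FLOAT64 = 8
n_samples = 1_284_822   # pixels in the 1134 x 1133 subset
n_features = 6          # Landsat bands 1, 2, 3, 4, 5, 7

feature_bytes = n_samples * n_features * BYTES_PER_FLOAT64
affinity_bytes = n_samples ** 2 * BYTES_PER_FLOAT64   # dense n_samples x n_samples

print(f"feature matrix X: {feature_bytes / 1e6:.1f} MB")        # 61.7 MB
print(f"dense affinity matrix: {affinity_bytes / 1e12:.1f} TB")  # 13.2 TB
```

The feature matrix fits comfortably in a laptop's memory, while the dense affinity matrix is about five orders of magnitude larger.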

X = flat_input.values.astype('float').T
X.shape
(1284822, 6)
X = da.from_array(X, chunks=100_000)
X = client.persist(X)

And we’ll fit the estimator.

clf = SpectralClustering(n_clusters=4, random_state=0,
                         gamma=None,
                         kmeans_params={'init_max_iter': 5},
                         persist_embedding=True)
%time clf.fit(X)
CPU times: user 27.3 s, sys: 17.5 s, total: 44.8 s
Wall time: 27.2 s
SpectralClustering(gamma=None, kmeans_params={'init_max_iter': 5}, n_clusters=4,
                   persist_embedding=True, random_state=0)
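The n_samples x n_samples affinity matrix described earlier is never built in full: dask-ml's SpectralClustering approximates it with the Nyström method, using `n_components` sampled rows of `X` (100 by default). The NumPy sketch below illustrates that idea and is not dask-ml's implementation: it picks m landmark rows, computes the n x m RBF block C and the m x m block W among the landmarks, and approximates the full affinity as C W⁺ Cᵀ. The helper names, the toy data, and `gamma=1/6` are assumptions for the demo.

```python
import numpy as np

def rbf_affinity(A, B, gamma):
    """Dense RBF affinity exp(-gamma * ||a - b||^2) between rows of A and B."""
    sq_dists = (
        np.sum(A**2, axis=1)[:, None]
        + np.sum(B**2, axis=1)[None, :]
        - 2.0 * A @ B.T
    )
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))

def nystrom_affinity(X, n_landmarks, gamma, seed=0):
    """Hypothetical helper: approximate the n x n affinity as C @ pinv(W) @ C.T."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(X.shape[0], size=n_landmarks, replace=False)
    C = rbf_affinity(X, X[idx], gamma)    # n x m block
    W = C[idx]                            # m x m block among the landmarks
    # Materializing the n x n product is only for this comparison; real
    # implementations work with C and W directly.
    return C @ np.linalg.pinv(W) @ C.T

rng = np.random.default_rng(42)
X_toy = rng.normal(size=(500, 6))         # toy stand-in for (n_samples, n_features)
K_full = rbf_affinity(X_toy, X_toy, gamma=1 / 6)
K_approx = nystrom_affinity(X_toy, n_landmarks=100, gamma=1 / 6)
rel_err = np.linalg.norm(K_full - K_approx) / np.linalg.norm(K_full)
print(f"relative Frobenius error with 100 of 500 landmarks: {rel_err:.3f}")
```

Only the n x m block grows with the number of pixels, which is what keeps the memory footprint manageable; more landmarks improve the approximation at the cost of a longer fit.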
labels = clf.assign_labels_.labels_.compute()
labels.shape
(1284822,)
labels = labels.reshape(bands[0].shape)
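Reshaping with `labels.reshape(bands[0].shape)` works because `stack(z=('y', 'x'))` orders pixels with y as the outer, slowest-varying index (visible in the repeated y values of the stacked coordinates), which matches NumPy's default C (row-major) order. A small NumPy sketch of the round trip, using a toy array in place of `bands` (all names here are illustrative):

```python
import numpy as np

# Toy stand-in for `bands`: (band, y, x) with 6 bands on a 4 x 3 grid.
n_band, ny, nx = 6, 4, 3
cube = np.arange(n_band * ny * nx, dtype=float).reshape(n_band, ny, nx)

# Equivalent of bands.stack(z=('y', 'x')) followed by .T:
# rows are pixels in row-major (y, then x) order, columns are bands.
X_demo = cube.reshape(n_band, ny * nx).T      # (n_samples, n_features)

# Pretend each pixel's label is its row index in X_demo.
labels_demo = np.arange(X_demo.shape[0])
label_grid = labels_demo.reshape(ny, nx)      # like labels.reshape(bands[0].shape)

# The pixel at (y=2, x=1) is sample 2 * nx + 1 in the flattened order.
assert np.array_equal(X_demo[2 * nx + 1], cube[:, 2, 1])
assert label_grid[2, 1] == 2 * nx + 1
```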
hv.Layout([regrid(hv.Image(band, kdims=['x', 'y'])) for band in bands]) 
hv.Layout([regrid(hv.Image(band, kdims=['x', 'y'])) for band in bands[3:]])
hv.Image(labels)