forked from lbferreira/python_lecture
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathraster_dataset.py
More file actions
70 lines (61 loc) · 2.17 KB
/
raster_dataset.py
File metadata and controls
70 lines (61 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import annotations
from typing import Any, List
import numpy as np
import rasterio
from affine import Affine
class RasterDataset:
    """In-memory raster: a band-stacked array plus its georeferencing metadata.

    ``data`` is laid out as ``(bands, rows, cols)``: ``to_geotiff`` writes
    ``shape[0]`` bands of ``shape[1]`` rows by ``shape[2]`` columns, and
    ``get_band_data`` selects a band along the first axis.
    """

    def __init__(
        self,
        data: np.ndarray,
        crs: Any,
        transform: Affine,
        band_names: List[str],
        nodata: Any,
    ) -> None:
        """Stores raster data and related metadata.

        Args:
            data (np.ndarray): raster data with shape (bands, rows, cols).
            crs (Any): Coordinate Reference System.
            transform (Affine): Affine transformation matrix.
            band_names (List[str]): band names, one per entry along the first
                axis of ``data``.
            nodata (Any): nodata value.
        """
        self.data = data
        self.crs = crs
        self.transform = transform
        self.band_names = band_names
        self.nodata = nodata

    def __repr__(self) -> str:
        """Returns a compact summary (array shape/dtype, not pixel values)."""
        return (
            f"{type(self).__name__}(shape={self.data.shape}, "
            f"dtype={self.data.dtype}, crs={self.crs!r}, "
            f"band_names={list(self.band_names)!r}, nodata={self.nodata!r})"
        )

    @classmethod
    def from_geotiff(cls, file: str) -> RasterDataset:
        """Reads a GeoTIFF file and returns a RasterDataset object.

        All bands are read into memory. Band names are taken from the file's
        band descriptions (rasterio reports ``None`` for a band that has no
        description).

        Args:
            file (str): path of the GeoTIFF to read.

        Returns:
            RasterDataset: an instance of ``cls``, so subclasses using this
            constructor get their own type back.
        """
        with rasterio.open(file) as src:
            data = src.read()
            # rasterio exposes descriptions as a tuple; store a list to match
            # the declared ``List[str]`` type of ``band_names``.
            band_names = list(src.descriptions)
            crs = src.crs
            transform = src.transform
            nodata = src.nodata
        return cls(data, crs, transform, band_names, nodata)

    def get_band_data(self, band_name: str) -> np.ndarray:
        """Returns the data for the specified band.

        Args:
            band_name (str): name of the band, as listed in ``band_names``.

        Returns:
            np.ndarray: ``data[i]`` where ``i`` is the band's position; this
            is a NumPy view into ``data``, not a copy.

        Raises:
            ValueError: if ``band_name`` is not in ``band_names``.
        """
        return self.data[self._get_band_index(band_name)]

    def _get_band_index(self, band_name: str) -> int:
        """Returns the position of ``band_name`` in ``band_names``.

        If a name occurs more than once, the first position is returned.

        Raises:
            ValueError: if the band is not present. ValueError (not KeyError)
                is kept so callers already catching the ``.index`` failure
                continue to work.
        """
        try:
            return self.band_names.index(band_name)
        except ValueError:
            raise ValueError(
                f"Band {band_name!r} not found; available bands: "
                f"{list(self.band_names)!r}"
            ) from None

    def to_geotiff(self, file: str) -> None:
        """Writes the raster data to a GeoTIFF file.

        Band names are stored as the GeoTIFF band descriptions.

        Args:
            file (str): destination path.

        Raises:
            ValueError: if ``data`` is not 3-D, or if ``band_names`` does not
                have exactly one entry per band. Both checks run before the
                file is opened, so an invalid dataset does not leave a
                partially written file behind.
        """
        if self.data.ndim != 3:
            raise ValueError(
                "Expected 3-D data shaped (bands, rows, cols); "
                f"got shape {self.data.shape}"
            )
        count, height, width = self.data.shape
        if len(self.band_names) != count:
            raise ValueError(
                f"Expected {count} band names (one per band), "
                f"got {len(self.band_names)}"
            )
        with rasterio.open(
            file,
            "w",
            driver="GTiff",
            height=height,
            width=width,
            count=count,
            dtype=self.data.dtype,
            crs=self.crs,
            transform=self.transform,
            nodata=self.nodata,
        ) as dst:
            dst.write(self.data)
            dst.descriptions = tuple(self.band_names)