|
26 | 26 | "outputs": [], |
27 | 27 | "source": [ |
28 | 28 | "import xarray as xr\n", |
| 29 | + "\n", |
29 | 30 | "import xbatcher" |
30 | 31 | ] |
31 | 32 | }, |
|
46 | 47 | "metadata": {}, |
47 | 48 | "outputs": [], |
48 | 49 | "source": [ |
49 | | - "store = \"az://carbonplan-share/example_cmip6_data.zarr\"\n", |
| 50 | + "store = 'az://carbonplan-share/example_cmip6_data.zarr'\n", |
50 | 51 | "ds = xr.open_dataset(\n", |
51 | 52 | " store,\n", |
52 | | - " engine=\"zarr\",\n", |
| 53 | + " engine='zarr',\n", |
53 | 54 | " chunks={},\n", |
54 | | - " backend_kwargs={\"storage_options\": {\"account_name\": \"carbonplan\"}},\n", |
| 55 | + " backend_kwargs={'storage_options': {'account_name': 'carbonplan'}},\n", |
55 | 56 | ")\n", |
56 | 57 | "\n", |
57 | 58 | "# the attributes contain a lot of useful information, but clutter the printout when we inspect the outputs\n", |
|
98 | 99 | "\n", |
99 | 100 | "bgen = xbatcher.BatchGenerator(\n", |
100 | 101 | " ds=ds,\n", |
101 | | - " input_dims={\"time\": n_timepoint_in_each_sample},\n", |
| 102 | + " input_dims={'time': n_timepoint_in_each_sample},\n", |
102 | 103 | ")\n", |
103 | 104 | "\n", |
104 | | - "print(f\"{len(bgen)} batches\")" |
| 105 | + "print(f'{len(bgen)} batches')" |
105 | 106 | ] |
106 | 107 | }, |
107 | 108 | { |
|
133 | 134 | "outputs": [], |
134 | 135 | "source": [ |
135 | 136 | "expected_n_batch = len(ds.time) / n_timepoint_in_each_sample\n", |
136 | | - "print(f\"Expecting {expected_n_batch} batches, getting {len(bgen)} batches\")" |
| 137 | + "print(f'Expecting {expected_n_batch} batches, getting {len(bgen)} batches')" |
137 | 138 | ] |
138 | 139 | }, |
139 | 140 | { |
|
153 | 154 | "source": [ |
154 | 155 | "expected_batch_size = len(ds.lat) * len(ds.lon)\n", |
155 | 156 | "print(\n", |
156 | | - " f\"Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch\"\n", |
| 157 | + " f'Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch'\n", |
157 | 158 | ")" |
158 | 159 | ] |
159 | 160 | }, |
|
179 | 180 | "\n", |
180 | 181 | "bgen = xbatcher.BatchGenerator(\n", |
181 | 182 | " ds=ds,\n", |
182 | | - " input_dims={\"time\": n_timepoint_in_each_sample},\n", |
183 | | - " batch_dims={\"time\": n_timepoint_in_each_batch},\n", |
| 183 | + " input_dims={'time': n_timepoint_in_each_sample},\n", |
| 184 | + " batch_dims={'time': n_timepoint_in_each_batch},\n", |
184 | 185 | " concat_input_dims=True,\n", |
185 | 186 | ")\n", |
186 | 187 | "\n", |
187 | | - "print(f\"{len(bgen)} batches\")" |
| 188 | + "print(f'{len(bgen)} batches')" |
188 | 189 | ] |
189 | 190 | }, |
190 | 191 | { |
|
217 | 218 | "source": [ |
218 | 219 | "n_timepoint_in_batch = 31\n", |
219 | 220 | "\n", |
220 | | - "bgen = xbatcher.BatchGenerator(ds=ds, input_dims={\"time\": n_timepoint_in_batch})\n", |
| 221 | + "bgen = xbatcher.BatchGenerator(ds=ds, input_dims={'time': n_timepoint_in_batch})\n", |
221 | 222 | "\n", |
222 | 223 | "for batch in bgen:\n", |
223 | | - " print(f\"last time point in ds is {ds.time[-1].values}\")\n", |
224 | | - " print(f\"last time point in batch is {batch.time[-1].values}\")\n", |
| 224 | + " print(f'last time point in ds is {ds.time[-1].values}')\n", |
| 225 | + " print(f'last time point in batch is {batch.time[-1].values}')\n", |
225 | 226 | "batch" |
226 | 227 | ] |
227 | 228 | }, |
|
249 | 250 | "\n", |
250 | 251 | "bgen = xbatcher.BatchGenerator(\n", |
251 | 252 | " ds=ds,\n", |
252 | | - " input_dims={\"time\": n_timepoint_in_each_sample},\n", |
253 | | - " batch_dims={\"time\": n_timepoint_in_each_batch},\n", |
| 253 | + " input_dims={'time': n_timepoint_in_each_sample},\n", |
| 254 | + " batch_dims={'time': n_timepoint_in_each_batch},\n", |
254 | 255 | " concat_input_dims=True,\n", |
255 | | - " input_overlap={\"time\": input_overlap},\n", |
| 256 | + " input_overlap={'time': input_overlap},\n", |
256 | 257 | ")\n", |
257 | 258 | "\n", |
258 | 259 | "batch = bgen[0]\n", |
259 | 260 | "\n", |
260 | | - "print(f\"{len(bgen)} batches\")\n", |
| 261 | + "print(f'{len(bgen)} batches')\n", |
261 | 262 | "batch" |
262 | 263 | ] |
263 | 264 | }, |
|
283 | 284 | "display(pixel)\n", |
284 | 285 | "\n", |
285 | 286 | "print(\n", |
286 | | - " f\"sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}\"\n", |
| 287 | + " f'sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}'\n", |
287 | 288 | ")\n", |
288 | 289 | "print(\n", |
289 | | - " f\"sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}\"\n", |
| 290 | + " f'sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}'\n", |
290 | 291 | ")" |
291 | 292 | ] |
292 | 293 | }, |
|
310 | 311 | "outputs": [], |
311 | 312 | "source": [ |
312 | 313 | "bgen = xbatcher.BatchGenerator(\n", |
313 | | - " ds=ds[[\"tasmax\"]].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n", |
314 | | - " input_dims={\"lat\": 9, \"lon\": 9, \"time\": 10},\n", |
315 | | - " batch_dims={\"lat\": 18, \"lon\": 18, \"time\": 15},\n", |
| 314 | + " ds=ds[['tasmax']].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n", |
| 315 | + " input_dims={'lat': 9, 'lon': 9, 'time': 10},\n", |
| 316 | + " batch_dims={'lat': 18, 'lon': 18, 'time': 15},\n", |
316 | 317 | " concat_input_dims=True,\n", |
317 | | - " input_overlap={\"lat\": 8, \"lon\": 8, \"time\": 9},\n", |
| 318 | + " input_overlap={'lat': 8, 'lon': 8, 'time': 9},\n", |
318 | 319 | ")\n", |
319 | 320 | "\n", |
320 | 321 | "for i, batch in enumerate(bgen):\n", |
321 | | - " print(f\"batch {i}\")\n", |
| 322 | + " print(f'batch {i}')\n", |
322 | 323 | "    # make sure the ordering of dimensions is consistent\n", |
323 | | - " batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n", |
| 324 | + " batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n", |
324 | 325 | "\n", |
325 | 326 | " # only use the first 9 time points as features, since the last time point is the label to be predicted\n", |
326 | 327 | " features = batch.tasmax.isel(time_input=slice(0, 9))\n", |
327 | 328 | "    # select the center pixel at the last time point as the label to be predicted\n", |
328 | 329 | "    # the actual lat/lon/time for each sample can be accessed in labels.coords\n", |
329 | 330 | " labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n", |
330 | 331 | "\n", |
331 | | - " print(\"feature shape\", features.shape)\n", |
332 | | - " print(\"label shape\", labels.shape)\n", |
333 | | - " print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape)\n", |
334 | | - " print(\"\")" |
| 332 | + " print('feature shape', features.shape)\n", |
| 333 | + " print('label shape', labels.shape)\n", |
| 334 | + " print('shape of lat of each sample', labels.coords['lat'].shape)\n", |
| 335 | + " print('')" |
335 | 336 | ] |
336 | 337 | }, |
337 | 338 | { |
|
350 | 351 | "outputs": [], |
351 | 352 | "source": [ |
352 | 353 | "for i, batch in enumerate(bgen):\n", |
353 | | - " print(f\"batch {i}\")\n", |
| 354 | + " print(f'batch {i}')\n", |
354 | 355 | "    # make sure the ordering of dimensions is consistent\n", |
355 | | - " batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n", |
| 356 | + " batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n", |
356 | 357 | "\n", |
357 | 358 | " # only use the first 9 time points as features, since the last time point is the label to be predicted\n", |
358 | 359 | " features = batch.tasmax.isel(time_input=slice(0, 9))\n", |
359 | | - " features = features.stack(features=[\"lat_input\", \"lon_input\", \"time_input\"])\n", |
| 360 | + " features = features.stack(features=['lat_input', 'lon_input', 'time_input'])\n", |
360 | 361 | "\n", |
361 | 362 | "    # select the center pixel at the last time point as the label to be predicted\n", |
362 | 363 | "    # the actual lat/lon/time for each sample can be accessed in labels.coords\n", |
363 | 364 | " labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n", |
364 | 365 | "\n", |
365 | | - " print(\"feature shape\", features.shape)\n", |
366 | | - " print(\"label shape\", labels.shape)\n", |
367 | | - " print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape, \"\\n\")" |
| 366 | + " print('feature shape', features.shape)\n", |
| 367 | + " print('label shape', labels.shape)\n", |
| 368 | + " print('shape of lat of each sample', labels.coords['lat'].shape, '\\n')" |
368 | 369 | ] |
369 | 370 | }, |
370 | 371 | { |
|