UNet

Explore UNet...

What is UNet

Upsample

  • Explanation of nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True): ok so if align_corners = False, I don't know how to calculate it yet lol. But if align_corners = True, it's pretty easy. A 2x2 input becomes a 4x4 output whose corners keep the original values:

    [1, 2
     3, 4]
    ->
    [1, a, b, 2
     c, d, e, f
     g, h, i, j
     3, k, l, 4]

    For row 1, the gap is (2 - 1) / 3 = 0.3333, so the row goes 1 -> 1.3333 -> 1.6667 -> 2. The last row is computed the same way, then the first column and the last column, and finally the values in the middle are interpolated the same way from the ones already computed.
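To double-check those numbers, here is a small sketch in plain Python of how I understand align_corners=True to work: each output index is mapped back to a fractional input position so that the corners line up, and the four neighbours are blended. The function name upsample_bilinear_align_corners is my own, not a PyTorch API.

```python
def upsample_bilinear_align_corners(grid, scale=2):
    """Bilinear upsampling of a 2D list, following align_corners=True semantics."""
    in_h, in_w = len(grid), len(grid[0])
    out_h, out_w = in_h * scale, in_w * scale
    out = []
    for i in range(out_h):
        # Map the output row to a fractional input row; first/last rows map exactly.
        y = i * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        wy = y - y0
        row = []
        for j in range(out_w):
            x = j * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            wx = x - x0
            # Interpolate along the row first, then between the two rows.
            top = grid[y0][x0] * (1 - wx) + grid[y0][x1] * wx
            bottom = grid[y1][x0] * (1 - wx) + grid[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        out.append(row)
    return out


result = upsample_bilinear_align_corners([[1.0, 2.0], [3.0, 4.0]])
for row in result:
    print([round(v, 4) for v in row])
```

The first row comes out as 1, 1.3333, 1.6667, 2, which matches the gap of 1/3 from the hand calculation. nn.Upsample with the same settings should give the same grid.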

  • In PyTorch, the argument is called 'dim' instead of 'axis' as in TensorFlow or NumPy.

  • Use "from .unet_parts import *" (a relative import, written without the .py extension) when unet_parts.py is in the same package as unet_model.py:
    project/
    ├── main.py
    ├── mypackage/
    │   ├── __init__.py
    │   ├── unet_parts.py
    │   └── unet_model.py  # Contains "from .unet_parts import *"
  • "from unet_parts import *" is an absolute import, which requires unet_parts to be accessible as a top-level module on Python's import path. This can break in certain contexts (e.g., when the code runs as part of a larger package).
  • A package contains modules (files) or subpackages.
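A quick way to see the difference between the two import styles is to build a throwaway copy of that layout and import it. This is only a sketch: the file contents, the DoubleConv placeholder class, and the extra unet_model_abs.py file are made up for the experiment.

```python
import importlib
import os
import sys
import tempfile

# Recreate project/mypackage/ in a temporary directory (hypothetical contents).
root = tempfile.mkdtemp()
pkg = os.path.join(root, "mypackage")
os.makedirs(pkg)

files = {
    "__init__.py": "",
    "unet_parts.py": "class DoubleConv:\n    pass\n",
    "unet_model.py": "from .unet_parts import *\n",     # relative import
    "unet_model_abs.py": "from unet_parts import *\n",  # absolute import
}
for name, body in files.items():
    with open(os.path.join(pkg, name), "w") as f:
        f.write(body)

# Only the project root is on the import path, as when running main.py.
sys.path.insert(0, root)
importlib.invalidate_caches()

unet_model = importlib.import_module("mypackage.unet_model")
relative_ok = hasattr(unet_model, "DoubleConv")

try:
    importlib.import_module("mypackage.unet_model_abs")
    absolute_ok = True
except ModuleNotFoundError:
    # unet_parts is not a top-level module, so the absolute form fails here.
    absolute_ok = False

print("relative import works:", relative_ok)
print("absolute import works:", absolute_ok)
```

The relative form resolves unet_parts inside mypackage, while the absolute form only works if the mypackage directory itself is on the import path.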

  • Regarding bilinear and factor: just jot down the sizes of down4 and up1 on paper, and how they fit together becomes clear.
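Here is that jotting-down as a small Python sketch. It assumes the channel layout of the commonly shared Pytorch-UNet code: factor = 2 if bilinear else 1, down4 producing 1024 // factor channels, a 512-channel skip connection, and up1 producing 512 // factor channels. Those numbers are assumptions, so check them against your own unet_model.py.

```python
def bottleneck_channels(bilinear):
    """Channel bookkeeping around down4 and up1 (assumed layout, see note above)."""
    factor = 2 if bilinear else 1
    down4_out = 1024 // factor
    # Bilinear upsampling keeps the channel count; a transposed conv is
    # assumed to halve it.
    upsampled = down4_out if bilinear else down4_out // 2
    skip = 512                   # channels arriving from the encoder skip connection
    up1_in = upsampled + skip    # concatenation along the channel dimension
    up1_out = 512 // factor
    return {"down4_out": down4_out, "up1_in": up1_in, "up1_out": up1_out}


for bilinear in (True, False):
    print(bilinear, bottleneck_channels(bilinear))
```

Under these assumptions up1 receives 1024 channels after concatenation in both modes, which is why the factor is needed to keep the two variants compatible.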

  • Is n_classes in this UNet architecture the number of classes per pixel, for example 2 classes for black and white? ChatGPT: Yes, n_classes in this UNet architecture refers to the number of output classes for each pixel in the image (pixel-level classes). It defines the number of categories or labels into which each pixel of the input image can be classified.

In a segmentation model like UNet:

The network produces an output with a shape of (batch_size, n_classes, height, width). For every pixel in the input image, the model outputs a vector of size n_classes with one score per class; after a softmax over the class dimension, these scores can be read as the probability of that pixel belonging to each class.
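To make the (n_classes, height, width) layout concrete, here is a tiny sketch in plain Python for a single image with 2 classes and 2x2 pixels. The scores are made up, and I'm assuming the network output is raw scores that still need a softmax.

```python
import math


def softmax(scores):
    """Turn a list of raw scores into probabilities that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# scores[c][y][x]: one image, n_classes = 2, height = width = 2 (made-up values)
scores = [
    [[2.0, 0.1], [1.2, -1.0]],  # class 0, e.g. black
    [[0.5, 1.5], [0.3, 2.0]],   # class 1, e.g. white
]
n_classes = len(scores)
height, width = len(scores[0]), len(scores[0][0])

mask = []
for y in range(height):
    row = []
    for x in range(width):
        # Collect this pixel's score for every class, then pick the most likely one.
        probs = softmax([scores[c][y][x] for c in range(n_classes)])
        row.append(max(range(n_classes), key=lambda c: probs[c]))
    mask.append(row)

print(mask)
```

The resulting mask has shape (height, width) and holds one class index per pixel. In PyTorch the same step is usually an argmax over dim=1 of the batched output.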

use_checkpointing

The use_checkpointing function applies activation checkpointing in PyTorch to reduce the memory the U-Net model needs during training.

How does activation checkpointing work?

During the forward pass, PyTorch normally stores all activations (intermediate tensors) in memory because they are needed to compute gradients during the backward pass. With activation checkpointing, some of these intermediate activations are not stored and are instead recomputed during the backward pass. This trades computation for memory: you save GPU memory during training but spend extra compute in the backward pass to recompute activations.
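A minimal sketch of the idea using torch.utils.checkpoint.checkpoint on a small made-up block. This is not the use_checkpointing method itself: the block, the tensor sizes, and the choice of use_reentrant=False are mine, for illustration only.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A small stand-in for one U-Net block (hypothetical sizes).
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)

x = torch.randn(1, 3, 16, 16, requires_grad=True)

# Activations inside `block` are not kept after the forward pass;
# they are recomputed when backward() needs them.
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()

print(tuple(out.shape), x.grad is not None)
```

The gradients are the same as without checkpointing. The only difference is that the block's forward pass runs a second time during backward.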

Magic in UNet:

  • Using elastic deformation for data augmentation on biomedical images is a brilliant idea.
  • No dense (fully connected) layers are used.
This post is licensed under CC BY 4.0 by the author.