

# Algorithm-System Co-Design for TinyML

Ligeng Zhu

ligeng@mit.edu, MIT



#### Today's AI is too BIG

Better models come with higher computational cost (vision)

Better models come with higher computational cost (NLP)



**Cloud** → **Mobile** → **Tiny** 



**Cloud AI**

GPUs/TPUs
ResNet

Data uploaded to the cloud for inference/training

**Mobile AI**

Smartphones
MobileNet

**Tiny AI**

IoT/Microcontrollers
MCUNet

#### Squeezing deep learning into IoT devices

- Billions of IoT devices around the world are based on microcontrollers
- Low-cost: affordable for low-income users, which helps democratize AI
- Low-power (mW): green AI, reduced carbon footprint
- Various applications:
  - Smart Home
  - Smart Manufacturing
  - Personalized Healthcare
  - Precision Agriculture



### TinyML is Challenging

#### Memory size is too small to hold DNNs





#### Overview



MCUNet: Tiny Deep Learning on IoT Devices [Lin et al., NeurIPS 2020]

MCUNetV2: Memory-Efficient Patch-based Inference for Tiny Deep Learning [Lin et al., NeurIPS 2021]

On-Device Training Under 256KB Memory [Lin et al., NeurIPS 2022]

#### MCUNetV1: Classification

#### Tiny vision application: visual wake words





Visual wake words dataset. [Chowdhery et al., arXiv 2019]

#### MCUNetV2: Detection

#### Advancing object detection by allowing a larger resolution





Face/mask detection



Person detection

#### Overview

Co-design



MCUNet: Tiny Deep Learning on IoT Devices [Lin et al., NeurIPS 2020]

MCUNetV2: Memory-Efficient Patch-based Inference for Tiny Deep Learning [Lin et al., NeurIPS 2021]

On-Device Training Under 256KB Memory [Lin et al., NeurIPS 2022]

## Tiny On-Device Training

- Sparse Update
- Tiny Training Engine (TTE)

#### Can We Learn on the Edge?

From tinyML inference to training

Cloud-based learning

A virtuous cycle:

- On-device learning:
  - Customization by adapting to user data / life-long learning
  - Better privacy, lower cost, and AIoT that works with limited connectivity
- Training is more expensive than inference
  - For example, it stores intermediate activations and runs an extra back-propagation pass

Edge devices have tight memory constraints. The training memory footprint of neural networks can easily exceed the limit.

- **Training** is more expensive than **inference** due to back-propagation, making it hard to fit on IoT devices (e.g., an MCU with only 256KB of SRAM).





1. Quantization-aware scaling



2. Sparse layer/tensor update



3. Tiny Training Engine




#### Real quantized graphs save memory, but are hard to optimize

- Fake quantization: most intermediate tensors are **still in FP32** format, so it cannot reduce the memory footprint.
- Real quantization: all tensors are in **int8/int32 format**, which reduces the memory footprint but makes optimization difficult.

Difficult to optimize:

- Mixed precisions: int8/int32/fp32...
- Lack of BatchNorm

Performance Comparison (average on 10 datasets)



#### Quantization leads to distorted gradient magnitudes

- Why is the training convergence worse?
- The scales of the weights and the gradients do not match in *real* quantized training!

[Figure: x-axis is the tensor index]

#### QAS addresses the optimization difficulty of quantized graphs

Quantization overview

$$\mathbf{\bar{y}}_{\text{int8}} = \text{cast2int8}[s_{\text{fp32}} \cdot (\mathbf{\bar{W}}_{\text{int8}}\mathbf{\bar{x}}_{\text{int8}} + \mathbf{\bar{b}}_{\text{int32}})],$$
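To make the notation concrete, here is a minimal pure-Python sketch of the int8 layer in the formula above. The name `cast2int8` comes from the formula; `quantized_linear` and the round-then-saturate behavior are assumptions for illustration, not the actual kernel used in the talk.

```python
def cast2int8(value):
    """Round to the nearest integer and saturate to the int8 range."""
    return max(-128, min(127, round(value)))


def quantized_linear(w_int8, x_int8, b_int32, scale_fp32):
    """y_int8 = cast2int8(s_fp32 * (W_int8 @ x_int8 + b_int32)).

    w_int8 is a list of rows (one per output channel); x_int8 is a vector.
    """
    outputs = []
    for row, bias in zip(w_int8, b_int32):
        # This Python int plays the role of the int32 accumulator.
        acc = sum(w * x for w, x in zip(row, x_int8)) + bias
        outputs.append(cast2int8(scale_fp32 * acc))
    return outputs


y = quantized_linear([[1, 2], [3, -4]], [10, 20], [5, -5], 0.25)
print(y)
```

Only the final rescale touches a floating-point value; everything else stays in integers, which is where the memory saving comes from.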

Per Channel scaling

$$\mathbf{W} = s_{\mathbf{W}} \cdot (\mathbf{W}/s_{\mathbf{W}}) \stackrel{\text{quantize}}{\approx} s_{\mathbf{W}} \cdot \mathbf{\bar{W}}, \quad \mathbf{G}_{\mathbf{\bar{W}}} \approx s_{\mathbf{W}} \cdot \mathbf{G}_{\mathbf{W}},$$

The weight-to-gradient ratios are off by $s_{\mathbf{W}}^{-2}$

$$\|\mathbf{\bar{W}}\|/\|\mathbf{G}_{\mathbf{\bar{W}}}\| \approx \|\mathbf{W}/s_{\mathbf{W}}\|/\|s_{\mathbf{W}}\cdot\mathbf{G}_{\mathbf{W}}\| = s_{\mathbf{W}}^{-2} \cdot \|\mathbf{W}\|/\|\mathbf{G}_{\mathbf{W}}\|.$$

Thus, re-scale the gradients

$$\tilde{\mathbf{G}}_{\bar{\mathbf{W}}} = \mathbf{G}_{\bar{\mathbf{W}}} \cdot s_{\mathbf{W}}^{-2}, \quad \tilde{\mathbf{G}}_{\bar{\mathbf{b}}} = \mathbf{G}_{\bar{\mathbf{b}}} \cdot s_{\mathbf{W}}^{-2} \cdot s_{\mathbf{x}}^{-2} = \mathbf{G}_{\bar{\mathbf{b}}} \cdot s^{-2}$$
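The rescaling rule can be checked numerically. The sketch below is a hedged illustration, not the authors' code: `qas_rescale` and `norm` are hypothetical helper names, the numbers are made up, and the bias rule reads $s = s_{\mathbf{W}} \cdot s_{\mathbf{x}}$ off the equality above.

```python
import math


def qas_rescale(grad_w, grad_b, s_w, s_x):
    """Apply quantization-aware scaling to gradients of the quantized graph.

    grad_w: list of rows, one row per output channel
    grad_b: list with one entry per output channel
    s_w:    per-channel weight scales (one per output channel)
    s_x:    scalar input scale
    """
    scaled_w = [[g * s ** -2 for g in row] for row, s in zip(grad_w, s_w)]
    scaled_b = [g * (s * s_x) ** -2 for g, s in zip(grad_b, s_w)]
    return scaled_w, scaled_b


def norm(vec):
    return math.sqrt(sum(v * v for v in vec))


# One channel: fp32 weights/gradients and their quantized-graph counterparts.
w, g, s = [2.0, 4.0], [0.5, 1.0], 0.5
w_bar = [v / s for v in w]   # W_bar ~ W / s_W
g_bar = [v * s for v in g]   # G_W_bar ~ s_W * G_W
(g_tilde,), _ = qas_rescale([g_bar], [0.0], [s], 1.0)

ratio_fp32 = norm(w) / norm(g)
ratio_naive = norm(w_bar) / norm(g_bar)   # off by s_W^-2
ratio_qas = norm(w_bar) / norm(g_tilde)   # matches the fp32 ratio again
```

With these toy values, the naive quantized ratio differs from the fp32 ratio by exactly $s_{\mathbf{W}}^{-2}$, and the rescaled gradient restores it.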

After applying QAS, the convergence of real quantized training is stable.

QAS improves accuracy over naive int8 training and performs no worse than fp32 training.



Question: Why is training memory much larger than inference memory?

Answer: Because of intermediate activations

Forward: 
$$\mathbf{a}_{i+1} = \mathbf{a}_i \mathbf{W}_i + \mathbf{b}_i$$

Backward: 
$$\frac{\partial L}{\partial \mathbf{W}_i} = \mathbf{a}_i^T \frac{\partial L}{\partial \mathbf{a}_{i+1}}$$

- Inference does not need to store activations; training does.
- Activations grow linearly with batch size, which is always 1 for inference.
- Even with bs=1, activations are usually larger than model weights.
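As a rough illustration of the points above, the following sketch counts activation elements for a chain of linear layers. It is a simplified model under stated assumptions: the function names and layer widths are hypothetical, counts are in elements rather than bytes, and inference is assumed to keep only one layer's input and output alive at a time.

```python
def inference_peak_activations(widths, batch_size=1):
    """Inference keeps only one layer's input and output alive at a time."""
    return batch_size * max(a + b for a, b in zip(widths, widths[1:]))


def training_saved_activations(widths, batch_size=1):
    """Training keeps every layer input a_i until dL/dW_i = a_i^T dL/da_{i+1} is computed."""
    return batch_size * sum(widths[:-1])


widths = [8, 16, 16, 16, 16, 4]  # hypothetical feature sizes a_0 .. a_5
print(inference_peak_activations(widths))
print(training_saved_activations(widths))
print(training_saved_activations(widths, batch_size=4))
```

The saved-activation count grows with both depth and batch size, while the inference peak depends only on the widest adjacent pair of layers.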



- Activations are the main bottleneck for on-device learning, not parameters.
- Previous methods focus on reducing the number of parameters or FLOPs, which does little to relieve the main bottleneck.

#### 2. Sparse Layer/Tensor Update

#### Full update



Model: ProxylessNAS-Mobile

Updating the whole model is too expensive:

- Need to save all intermediate activations (quite large)
- Need to store the updated weights in SRAM (Flash is read-only)



#### 2. Sparse Layer/Tensor Update

#### Last layer update



Model: ProxylessNAS-Mobile

Updating only the last layer is cheap

- No need to back-propagate to previous layers
- But the accuracy is low, which is not ideal.



#### **Bias-only update**



Updating only the bias part

- No need to store the activations
- Still needs to back-propagate to the first layer.

Forward:  $\mathbf{a}_{i+1} = \mathbf{a}_i \mathbf{W}_i + \mathbf{b}_i$ 

Backward:  $\frac{\partial L}{\partial \mathbf{W}_i} = \mathbf{a}_i^T \frac{\partial L}{\partial \mathbf{a}_{i+1}}, \qquad \frac{\partial L}{\partial \mathbf{b}_i} = \frac{\partial L}{\partial \mathbf{a}_{i+1}} = \frac{\partial L}{\partial \mathbf{a}_{i+2}} \mathbf{W}_{i+1}^T$ 
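The bias gradient formula can be sketched for a stack of purely linear layers, matching the forward rule above (no nonlinearity is modeled). All helper names and the toy weights are hypothetical. The backward pass reads only the weights and the incoming gradient, never a stored activation.

```python
def vec_mat(vec, mat):
    """Row vector (len n) times matrix (n x m) -> list of length m."""
    return [sum(vec[k] * mat[k][j] for k in range(len(vec))) for j in range(len(mat[0]))]


def vec_mat_t(vec, mat):
    """Row vector (len m) times the transpose of an n x m matrix -> list of length n."""
    return [sum(vec[j] * row[j] for j in range(len(vec))) for row in mat]


def forward(x, weights, biases):
    """a_{i+1} = a_i W_i + b_i; only the current activation is kept."""
    a = x
    for w, b in zip(weights, biases):
        a = [v + bias for v, bias in zip(vec_mat(a, w), b)]
    return a


def bias_only_backward(grad_out, weights):
    """Return dL/db_i for every layer using only the weights, never a_i."""
    grads_b, grad = [], grad_out
    for w in reversed(weights):
        grads_b.append(grad)          # dL/db_i = dL/da_{i+1}
        grad = vec_mat_t(grad, w)     # dL/da_i = dL/da_{i+1} W_i^T
    return grads_b[::-1]


weights = [[[1.0, 2.0], [3.0, 4.0]], [[1.0], [1.0]]]  # hypothetical 2x2 and 2x1 layers
biases = [[0.0, 0.0], [0.0]]
out = forward([1.0, 1.0], weights, biases)
grads = bias_only_backward([1.0], weights)
```

The loop still walks back through every layer, which is the cost noted in the second bullet.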



Updated synapses are sparse



#### Some layers are more important than others

1. Later layers contribute more to the accuracy.
2. The first point-wise convs are more important to accuracy.
3. The more channels are updated, the higher the accuracy.

#### **Sparse Layer/Tensor Update**

[Figure: sparse layer backpropagation and sparse tensor backpropagation, with a marker where backpropagation stops]

- Sparse layer update: no need to store activations
- Sparse tensor update: only store a subset of the activations
- Sparse update: no need to back-propagate through the early layers
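A minimal sketch of why a sparse tensor update needs only part of the activation, assuming the linear-layer rule $\partial L/\partial \mathbf{W}_i = \mathbf{a}_i^T \, \partial L/\partial \mathbf{a}_{i+1}$ from earlier. The function name, the choice of "first `keep` input channels", and the numbers are hypothetical.

```python
def sparse_weight_grad(saved_activation, grad_out):
    """dL/dW for the kept input channels only: outer(a_kept, dL/da_next)."""
    return [[a * g for g in grad_out] for a in saved_activation]


activation = [1.0, 2.0, 3.0, 4.0]   # full layer input a_i
keep = 2                            # hypothetical: update the first 2 input channels
saved = activation[:keep]           # only this slice must survive until backward
grad_w_rows = sparse_weight_grad(saved, [10.0, 20.0])
print(grad_w_rows)
```

Each row of the weight gradient depends on a single input channel, so the channels that are not updated never need to be stored.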



### Update Paradigms Comparison



### On-Device Training Under 256KB Memory



#### **Existing frameworks cannot fit**

- Runtime is heavy
  - Heavy dependencies and large binary size (>100MB static memory)
  - Auto-diff at runtime; low edge efficiency
- Memory is heavy
  - A lot of intermediate (and unused) buffers
  - Has to compute full gradients



#### Workflow of conventional training engine





Conventional training frameworks focus on **flexibility**, and auto-diff is performed at **runtime**.

Thus, any optimization adds runtime overhead.

#### TTE: Move workload from runtime to compile time





TTE moves most of the workload from runtime to compile time, which minimizes runtime overhead and enables extensive graph optimizations.



Example from a matrix multiplication with full update

**Forward**

```
y = mul(x, w) + b
```

**Backward**

```
dy/dx = mul(G, w)
dy/dw = mul(G^T, x)
dy/db = sum(G)
```

```
# forward
%0 = multiply(%x, %weight);
%1 = add(%0, %bias);

# backward
%3 = multiply(%grad, %weight);  ====> dy / dx
%4 = transpose(%grad);
%5 = multiply(%4, %x);          ====> dy / dw
%6 = sum(%grad, axis=-1);       ====> dy / db
(%3, %5, %6)
```



**Forward** 

Backward



Remove unnecessary computations from DAG via dependency analysis and dead-code elimination.
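The pruning step can be sketched as a reachability pass over the graph. This is an illustrative stand-in rather than TTE's implementation: `prune_dead_nodes` is a hypothetical name, and the graph below simply mirrors the `%0`..`%6` matrix-multiplication example with a bias-only update assumed as the use case.

```python
def prune_dead_nodes(graph, outputs):
    """Keep only the ops that the requested outputs depend on.

    graph:   dict mapping an op name to the names of its inputs
    outputs: names of the values that must still be produced
    """
    live, stack = set(), list(outputs)
    while stack:
        node = stack.pop()
        if node in live:
            continue
        live.add(node)
        stack.extend(graph.get(node, []))  # graph inputs such as %x have no entry
    return {name: inputs for name, inputs in graph.items() if name in live}


graph = {
    "%0": ["%x", "%weight"],     # forward multiply
    "%1": ["%0", "%bias"],       # forward add
    "%3": ["%grad", "%weight"],  # dy/dx
    "%4": ["%grad"],             # transpose for dy/dw
    "%5": ["%4", "%x"],          # dy/dw
    "%6": ["%grad"],             # dy/db
}

# Bias-only update: the weight gradient %5 is no longer requested.
pruned = prune_dead_nodes(graph, ["%1", "%3", "%6"])
print(sorted(pruned))
```

Dropping `%5` from the outputs removes both `%5` and the transpose `%4` that fed only it.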



Freely annotate **ANY** parameters

TTE will trim the computation accordingly.

[Figure: update schemes (■ updated, □ fixed): (a) full update, (b) bias-only update, (c) sparse layer update, (d) sparse tensor update]

```
# without annotations
fn (%x: Tensor[(10, 10), float32],
    %weight: Tensor[(10, 10), float32],
    %bias: Tensor[(10), float32],
    %grad: Tensor[(10), float32]),
  # forward
  %0 = multiply(%x, %weight);
  %1 = add(%0, %bias);
  # backward
  %3 = multiply(%grad, %weight);
  %4 = transpose(%grad);
  %5 = multiply(%4, %x);
  %6 = sum(%grad, axis=-1);
  (%3, %5, %6)

# with needs_grad annotations
fn (%x: Tensor[(10, 10), float32, needs_grad=True],
    %weight: Tensor[(20, 10), float32, needs_grad=0.5],
    %bias: Tensor[(20), float32, needs_grad=True],
    %grad: Tensor[(10, 20), float32]),
  # forward
  %0 = multiply(%x, %weight);
  %0.1 = slice(%x, begin=[0, 0], ends=[10, 10]);
  %1 = add(%0, %bias);
  # backward
  %3 = multiply(%grad, %weight);
  %4 = transpose(%grad);
  %5 = multiply(%4, %0.1);
  %6 = sum(%grad, axis=-1);
  (%3, %5, %6)
```

TTE automatically removes the buffers of pruned gradients from the computation graph.

#### Sparse update results



- Tiny Training Engine supports backward graph pruning and sparse update at the IR level.
- After graph pruning, unused weights and sub-tensors are removed from the DAG => 6.5-8.7x memory saving

#### Re-ordering reduces memory footprint

(a) Conventional way to update parameters

(b) Operator re-ordering

F: Forward, B: Backward, U: Update

Operator life-cycle analysis reveals the **memory** redundancy in the optimization step. After re-ordering, the redundant memory usage is eliminated from training.

Operator life-cycle analysis shows that the memory footprint can be greatly reduced by operator re-ordering.
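One way to read the F/B/U diagram is that the update of a layer can run right after that layer's backward step, so its gradient buffer is freed before the next one is allocated. The sketch below models only that interpretation, which is an assumption. The function name and buffer sizes are hypothetical, and activation memory is ignored.

```python
def peak_gradient_memory(grad_sizes, reorder):
    """Peak size of live gradient buffers during the backward pass.

    grad_sizes: gradient buffer size per layer, in backward order
    reorder:    if True, apply each update (U) right after its backward (B)
    """
    live = peak = 0
    for size in grad_sizes:
        live += size              # B: gradient buffer allocated
        peak = max(peak, live)
        if reorder:
            live -= size          # U: weights updated, buffer freed immediately
    return peak


sizes = [4, 8, 2]  # hypothetical gradient buffer sizes
print(peak_gradient_memory(sizes, reorder=False))
print(peak_gradient_memory(sizes, reorder=True))
```

Without re-ordering, all gradient buffers are alive at once; with it, the peak equals the largest single buffer.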

#### Smaller memory usage, faster training speed



20x smaller memory



23x faster speed

### Tiny Training

#### **Co-design Results**



Co-design reduces the training memory by 2300x at the same transfer accuracy.

The numbers are measured with MobileNetV2-w0.35, batch size 1, and resolution 128x128.



https://www.bilibili.com/video/BV1qv4y1d7MV/



https://youtu.be/XaDCO8YtmBw



### **Extending TTE to More Platforms**

#### Accelerate on-device training on diverse edge hardware

- We extend TTE to support:
  - Diverse models (CNN + Transformers)
  - Diverse frontends
    - PyTorch
    - TensorFlow
    - JAX
  - Diverse hardware backends
    - Apple M1
    - Raspberry Pi
    - Smartphones



### **Extending TTE to More Platforms**

#### Consistently speed up training on diverse platforms

TTE provides systematic support for sparse update schemes on vision and NLP models, leading to consistent memory savings at the same training accuracy.





Results measured on Raspberry Pi 4B+.

### Media







(Homepage highlight)

(Homepage highlight)



### **Future Work**

- Scale up to LLM/foundation models
  - LLMs are hard to serve/fine-tune due to their huge size
  - GPU memory is not enough to serve 100-billion-parameter models
  - Our techniques (e.g., quantization, sparse update, system support) help democratize LLMs
- Collaboration welcome!

