bowang-lab/EchoJEPA: Joint-embedding predictive architectures (JEPAs) for medical video.


Alif Munim, Adibvafa Fallahpour, Teodora Szasz, Ahmadreza Attarpour, River Jiang, Brana Sooriyakanthan, Maala Sooriyakanthan, Heather Whitney, Jeremy Slivnick, Barry Rubin, Wendy Tsang, Bo Wang

[Paper] [Website] [BibTex]

Official Pytorch codebase for EchoJEPA.

Foundation models for echocardiography often struggle to disentangle anatomical signal from the stochastic speckle and acquisition artifacts inherent to ultrasound. We present EchoJEPA, a foundation model trained on 18 million echocardiograms across 300K patients, representing the largest pretraining corpus for this modality to date. By leveraging a latent predictive objective, EchoJEPA learns robust anatomical representations that ignore speckle noise. We validate this using a novel multi-view probing framework with frozen backbones, where EchoJEPA outperforms state-of-the-art baselines by approximately 20% in left ventricular ejection fraction (LVEF) estimation and 17% in right ventricular systolic pressure (RVSP) estimation. The model also exhibits remarkable sample efficiency, reaching 79% view classification accuracy with only 1% of labeled data versus 42% for the best baseline trained on 100%. Crucially, EchoJEPA demonstrates superior generalization, degrading by only 2% under physics-informed acoustic perturbations compared to 17% for competitors. Most remarkably, its zero-shot performance on pediatric patients surpasses fully fine-tuned baselines, establishing latent prediction as a superior paradigm for robust, generalizable medical AI.

EchoJEPA models trained on just 1% of labeled data outperform baselines trained on 100%. This efficiency implies that latent prediction yields dense representations capable of defining the view manifold with minimal supervision, as evidenced by the distinct anatomical clustering in the figure below.

EchoJEPA demonstrates anatomical localization, focusing on the mitral valve leaflets, ventricular walls, and annulus while ignoring sector background. Received attention clusters at Doppler jet edges while given attention localizes on valve structures generating flow. Across the cardiac cycle, focus shifts from valve tips during opening to chamber walls during relaxation, indicating it interprets the echocardiogram as a functional biological system.

conda create -n vjepa2-312 python=3.12
conda activate vjepa2-312
pip install .  # or `pip install -e .` for development mode

Pretraining can also be run locally or distributed. Pretraining and cooldown training phases are run with the same command using different configs. These sample commands launch initial training of a ViT-L model on MIMIC-IV-ECHO, a dataset of 525K echocardiograms which can be accessed through PhysioNet.

python -m app.main --fname configs/train/vitl16/pretrain-mimic-224px-16f.yaml \
  --devices cuda:0
python -m app.main_distributed \
  --fname configs/train/vitl16/pretrain-mimic-224px-16f.yaml
  --time 6000
  --account my_account --qos=my_qos

The pretrain dataset file needs to be set under data.datasets and looks something like this:

mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0001.mp4 0
mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0006.mp4 0
mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0007.mp4 0
mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0008.mp4 0
mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0009.mp4 0

Since we are doing self-supervised pre-training, all the video labels are set to zero. You can begin pretraining from any of the pre-trained V-JEPA models below:

We keep the pretraining configuration mostly the same as in V-JEPA 2, but adjust some of the sampling and augmentation parameters for echocardiography:

app: vjepa
nodes: 1
tasks_per_node: 8
cpus_per_task: 16
mem_per_gpu: 220G
folder: checkpoints/pretrain/mimic/vjepa2_vitl_224px_16f
data:
  dataset_type: VideoDataset
  datasets:
  - /home/sagemaker-user/user-default-efs/vjepa2/data/csv/mimic_annotations_s3.csv # 525k echocardiogram video clips (224px)
  datasets_weights:
  - 1.0
  batch_size: 128
  crop_size: 224              # 

If you are not training from scratch, set optimization.checkpoint to your downloaded checkpoint path. Make sure to scale your learning rates!

Probe-based evaluation consists in training an attentive probe on top of frozen V-JEPA 2 features. We provide training scripts for training your own probes, and checkpoints to run inference directly.

Classification Dataset Format

For classification, prepare a two-column CSV. It should be space-delimited, with the first column being the path to the MP4, and the second being your integer label.

data/echo_views_22k/19068955.mp4 5
data/echo_views_22k/19076133.mp4 7
data/echo_views_22k/19083831.mp4 2
data/echo_views_22k/19086809.mp4 2
data/echo_views_22k/19089161.mp4 5

Regression Dataset Format

For regression, we perform standard scaling (Z-score normalization) by fitting distribution parameters (mean and standard deviation) solely on the training data to prevent data leakage, then transform all data splits to a centered distribution with zero mean and unit variance to stabilize model optimization. You can use the following code:

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import pickle

# 1. Initialize the Scaler
# We use StandardScaler to center the data (Mean=0, Std=1)
scaler = StandardScaler()

# 2. Fit ONLY on the Training Set
# This prevents information leakage from the future/test set
train_values = train_clean['Value'].values.reshape(-1, 1)
scaler.fit(train_values)

print(f"Scaler fitted. Mean: {scaler.mean_[0]:.4f}, Std: {scaler.scale_[0]:.4f}")

# 3. Transform All Splits
# We create a new column 'norm_value' which the model will try to predict
train_clean['norm_value'] = scaler.transform(train_clean['Value'].values.reshape(-1, 1))
val_clean['norm_value']   = scaler.transform(val_clean['Value'].values.reshape(-1, 1))
test_clean['norm_value']  = scaler.transform(test_clean['Value'].values.reshape(-1, 1))

# 4. Save the Scaler
# CRITICAL: You need this file to convert predictions back to real LVEF % later
with open('lvef_scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)

# --- Verification ---
print("\nNormalization Check (Train Set):")
print(train_clean['norm_value'].describe())
# Expected: Mean ~ 0.0, Std ~ 1.0

print("\nExample Data:")
print(train_clean[['Value', 'norm_value']].head(3))

The resulting CSV for train, val, and test should look something like this:

data/echo_a4c_lvef/2230801.mp4 -3.4486802913030026
data/echo_a4c_lvef/3260170.mp4 -0.16931876118450664
data/echo_a4c_lvef/2758271.mp4 0.7278549632852218
data/echo_a4c_lvef/4291596.mp4 0.7278549632852218
data/echo_a4c_lvef/2350500.mp4 -0.9335615677497962
data/echo_a4c_lvef/2351242.mp4 -0.9335615677497962
data/echo_a4c_lvef/2761632.mp4 -0.9335615677497962
data/echo_a4c_lvef/2351284.mp4 -0.9335615677497962
data/echo_a4c_lvef/3257799.mp4 -0.9335615677497962
data/echo_a4c_lvef/2759135.mp4 0.8186436513424096

Make sure to make a note of the mean and standard deviation of your dataset (or save the scaling pickle), as we will need these for inference. If you are preparing a multi-view dataset, the dataset will have additional columns for your videos. The last column is reserved for the label.

Now, we can prepare the training config. Here is an example config for regression

app: vjepa  
cpus_per_task: 32  
folder: /path/to/experiments/lvef_regression
mem_per_gpu: 80G  
nodes: 1  
tasks_per_node: 8  
num_workers: 8  
  
eval_name: video_classification_frozen  # 

The settings are largely the same for classification, except we change task_type: regression to task_type: classification and we replace num_targets: 1 with num_classes: N.
For multi-view, set model_kwargs.module_name to evals.video_classification_frozen_multi.modelcustom.vit_encoder_multiclip. See configs/eval/vitg-384 for more examples.

Evaluations can be run either locally, or distributed via SLURM. (Running locally is useful for debugging and validation).
Use provided training configs under “Evaluation Attentive Probes”. These configs allow to train multiple probes in parallel with various optimization parameters.
Change filepaths as needed (e.g. folder, checkpoint, dataset_train, dataset_val) to match locations of data and downloaded checkpoints on your local filesystem.
Change # nodes and local batch size as needed to not exceed available GPU memory.

To run locally, specify the GPUs to use. For example, training an LVEF probe:

python -m evals.main --fname configs/eval/vitg-384/lvef/echojepa_lvef.yaml \
  --devices cuda:0 cuda:1
python -m evals.main_distributed \
  --fname configs/eval/vitg-384/lvef/echojepa_lvef.yaml  \
  --time 8600 \
  --account my_account --qos=my_qos

Additional configs can be found for other tasks, including configs/eval/vitg-384/view and configs/eval/vitg-384/rvsp. Each directory has ready-made configs for PanEcho, EchoPrime, VideoMAE, and EchoJEPA. All that needs to be modified are the checkpoint and dataset paths. There are additional configs for EchoJEPA-L under configs/eval/vitl.

To do inference with a trained probe, you can use the inference configs under configs/inference/vitg-384 or configs/inference/vitl. For example, performing inference using a trained LVEF probe:

python -m evals.main --fname configs/inference/vitl/lvef.yaml --devices cuda:0 cuda:1 cuda:2

For regression, these are the key parameters that need to be changed:

# --- CRITICAL INFERENCE SETTINGS ---
val_only: true       # 

For inference, we simply zero out the optimization parameters and set the num_epochs to 1. Batch size can be configured based on your GPU memory.

  optimization:
    # You can increase batch size for inference since gradients are off
    batch_size: 4
    
    # Zero out learning rates for safety (though val_only prevents updates anyway)
    multihead_kwargs:
    - final_lr: 0.0
      final_weight_decay: 0.0
      lr: 0.0
      start_lr: 0.0
      warmup: 0.0
      weight_decay: 0.0

    num_epochs: 1 # Run once
    use_bfloat16: true
    use_pos_embed: false

The other settings should be identical to your training config. See configs under configs/inference/vitg-384 or configs/inference/vitl for more examples.

.
├── app                                    # training loops
│   ├── vjepa                              #   video JEPA pre-training
│   ├── main_distributed.py                #   entrypoint for launch app on slurm cluster
│   └── main.py                            #   entrypoint for launch app locally on your machine
├── configs                                # config files with experiment params for training and evaluation
│   ├── train                              #   pretraining (phase 1), cooldown (phase 2), and action-conditioned training
│   └── eval                               #   frozen evaluations
├── evals                                  # evaluation loops training an attentive probe with frozen backbone...
│   ├── video_classification_frozen        #   single-view echocardiogram probes
│   ├── video_classification_frozen_multi  #   multi-view echocardiogram probes
│   ├── main_distributed.py                #   entrypoint for distributed evaluations
│   └── main.py                            #   entrypoint for locally-run evaluations
├── src                                    # the package
│   ├── datasets                           #   datasets, data loaders, ...
│   ├── models                             #   model definitions
│   ├── masks                              #   mask collators, masking utilities, ...
│   └── utils                              #   shared utilities
├── tests                                  # unit tests for some modules in `src`

The majority of V-JEPA 2 is licensed under MIT, however portions of the project are available under separate license terms:

src/datasets/utils/video/randaugment.py
src/datasets/utils/video/randerase.py
src/datasets/utils/worker_init_fn.py

are licensed under the Apache 2.0 license.

If you find this repository useful in your research, please consider giving a star ⭐ and a citation

@misc{munim2026echojepalatentpredictivefoundation,
      title={EchoJEPA: A Latent Predictive Foundation Model for Echocardiography}, 
      author={Alif Munim and Adibvafa Fallahpour and Teodora Szasz and Ahmadreza Attarpour and River Jiang and Brana Sooriyakanthan and Maala Sooriyakanthan and Heather Whitney and Jeremy Slivnick and Barry Rubin and Wendy Tsang and Bo Wang},
      year={2026},
      eprint={2602.02603},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2602.02603}, 
}



Source link

Leave a Reply

Your email address will not be published. Required fields are marked *