Building upon multimodal fusion and interpretable learning, this study adapts and extends
methods for AF recurrence prediction. It specifically compares the impact of four
open-source large language models (LLaMA-7B, Phi-2 (2.7B), Mistral-7B, and MedGemma-27B) on
the representation of Holter ECG reports, echocardiography reports, and surgical records.
A convolutional neural network is employed in the structured feature branch for
representation learning and classification. Furthermore, a generative adversarial network
is introduced to augment categories, mitigating the imbalance caused by the scarcity of
recurrence samples. The dataset comprises multimodal information from the perioperative
period and follow-up, including 28 structured features and textual data from Holter ECG
reports, echocardiography reports, and surgical records. The structured features are:
Age, Gender, BMI, Systolic Blood Pressure, Diastolic Blood Pressure, AF Duration,
Hypertension, Coronary Artery Disease, Diabetes, CHA2DS2-VASc, HAS-BLED, AF type, LAD,
LVEF, HbA1c, FPG, TC, TG, HDL, LDL, Albumin, ALT, AST, ALP, Creatinine, eGFR, and use of
class I/III or class II antiarrhythmic drugs.
Preprocessing begins with systematic data cleaning on structured channels: for continuous
variables, a combined outlier detection method based on clinically plausible range
constraints and the IQR rule is used, with extreme outliers beyond the threshold
truncated at quantiles while preserving order information (Supplementary A5). Missing
values are handled using a multiple imputation strategy: continuous variables are
predicted and imputed using regression models built with multiple imputation by chained
equations (MICE), with mean and variance adjustments to avoid shrinkage;
categorical variables are imputed using mode or conditional sampling under Bayesian
smoothed frequency encoding to preserve category co-occurrence relationships. To
standardize scales, continuous features are z-score normalized, while retaining scaling
parameters for external validation; categorical variables are subjected to target
leakage-free one-hot encoding or ordinal encoding (for clearly monotonic ordinal
features). All encoders are fitted within the training fold and transformed on the
validation fold and test set to prevent information leakage. For text channels,
lightweight cleaning and normalization are performed on Holter ECG reports,
echocardiography reports, and surgical records, including special symbol
unification, unit standardization, date and identifier de-identification, and medical
abbreviation expansion; subsequently, fragment-based sentence segmentation and keyword
localization are used to enhance key point density.
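The cleaning and leakage-free scaling steps above can be sketched as follows; the variable names, the IQR multiplier k=3, and the use of scikit-learn's IterativeImputer as the MICE-style imputer are our illustrative assumptions, not the exact implementation:

```python
import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer
from sklearn.preprocessing import StandardScaler

def iqr_clip(fit_X, X, k=3.0):
    # Truncate extreme outliers at IQR-based bounds fitted on the training fold
    q1 = np.nanpercentile(fit_X, 25, axis=0)
    q3 = np.nanpercentile(fit_X, 75, axis=0)
    iqr = q3 - q1
    return np.clip(X, q1 - k * iqr, q3 + k * iqr)

rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 5))
X_train[rng.random(X_train.shape) < 0.1] = np.nan   # inject missingness
X_test = rng.normal(size=(50, 5))

X_train_c = iqr_clip(X_train, X_train)
X_test_c = iqr_clip(X_train, X_test)                # bounds come from training fold

# MICE-style chained-equations imputation, fitted on the training fold only
imputer = IterativeImputer(max_iter=10, sample_posterior=True, random_state=0)
X_train_imp = imputer.fit_transform(X_train_c)
X_test_imp = imputer.transform(X_test_c)

# z-score normalization; scaler parameters are retained for external validation
scaler = StandardScaler().fit(X_train_imp)
X_train_z = scaler.transform(X_train_imp)
X_test_z = scaler.transform(X_test_imp)
```

Fitting the clipper, imputer, and scaler on the training fold and only transforming the validation and test folds is what prevents the information leakage noted above.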
To implement LLM embedding + structured CNN late fusion, we construct four parallel text
encoders, each fine-tuned from a pre-trained LLM (LLaMA, Phi-2, Mistral, MedGemma).
Fine-tuning combines continued pre-training and instruction alignment: first,
domain-specific continued pre-training is performed on de-identified Holter ECG
reports, echocardiography reports, and surgical records from our
institution to improve clinical terminology coverage and syntactic robustness;
subsequently, supervised contrastive learning with a classification auxiliary objective
is used to moderately update the LLMs. To balance computational power and portability,
LoRA/QLoRA is used for low-rank adaptation, freezing most of the lower-layer weights and
training low-rank adapter parameters only in the mid-to-high-layer attention blocks and word
embeddings. Text representations are uniformly taken from the penultimate layer's
[CLS]-equivalent pooled vector and a token-attention-based weighted average, concatenated
to form a 1024-dimensional embedding, and then linearly projected to 256 dimensions to
match the representation space of the structured branch. For fair comparison, the four
LLMs independently train their respective text encoders and downstream fusion
classification heads, while the remaining training and evaluation procedures remain
consistent, resulting in four comparable multimodal models.
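A minimal sketch of the pooling-and-projection head described above, assuming a 512-dimensional penultimate hidden state so that the concatenated vector is 1024-dimensional; the class name, the learned single-layer attention scorer, and the hidden size are illustrative assumptions:

```python
import torch
import torch.nn as nn

class TextEmbeddingHead(nn.Module):
    """Combine a [CLS]-equivalent pooled vector with a token-attention
    weighted average of penultimate-layer states, then project to 256-d."""
    def __init__(self, hidden=512):
        super().__init__()
        self.attn_score = nn.Linear(hidden, 1)   # learned per-token attention score
        self.proj = nn.Linear(2 * hidden, 256)   # 1024 -> 256 when hidden=512

    def forward(self, hidden_states, mask):
        # hidden_states: (B, T, H) penultimate-layer states; mask: (B, T) 0/1
        pooled = hidden_states[:, 0]             # [CLS]-equivalent token
        scores = self.attn_score(hidden_states).squeeze(-1)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        w = torch.softmax(scores, dim=-1).unsqueeze(-1)
        weighted = (w * hidden_states).sum(dim=1)
        return self.proj(torch.cat([pooled, weighted], dim=-1))

head = TextEmbeddingHead()
h = torch.randn(4, 128, 512)                     # dummy hidden states
m = torch.ones(4, 128, dtype=torch.long)
emb = head(h, m)                                 # (4, 256) text embedding
```

The same head is re-trained per LLM, so the 256-dimensional output space stays comparable across the four text encoders.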
The structured branch uses ResNet1D as a 1D CNN backbone to learn local interactions and
hierarchical features from the structured inputs. Specifically, the encoded structured
features (a vector of length 30 derived from the 28 raw features) are stacked in a fixed
order to form a "feature sequence", which is
fed into a network containing three convolutional blocks: Conv1d (channels=32, kernel
size=3, stride=1) + BatchNorm + GELU + MaxPool, followed by a cascade of Conv1d(64, 3)
and Conv1d(128, 3). The pooling stride of each layer is controlled to cover different
receptive fields and extract cross-feature interactions. The convolutional output is then
subjected to global average pooling and a linear projection to obtain a 256-dimensional
structured embedding, which is regularized with dropout and layer normalization to improve
generalization.
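The structured branch can be sketched as below; the padding, pooling window, dropout rate, and the final linear projection from 128 channels to the 256-dimensional embedding are our assumptions where the text is not explicit:

```python
import torch
import torch.nn as nn

def conv_block(c_in, c_out):
    # Conv1d + BatchNorm + GELU + MaxPool, as described in the text
    return nn.Sequential(
        nn.Conv1d(c_in, c_out, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm1d(c_out), nn.GELU(), nn.MaxPool1d(2),
    )

class StructuredBranch(nn.Module):
    """1D CNN over the fixed-order structured 'feature sequence'."""
    def __init__(self, seq_len=30):
        super().__init__()
        self.blocks = nn.Sequential(
            conv_block(1, 32), conv_block(32, 64), conv_block(64, 128))
        self.pool = nn.AdaptiveAvgPool1d(1)      # global average pooling
        self.proj = nn.Linear(128, 256)          # match the 256-d fusion space
        self.norm = nn.LayerNorm(256)
        self.drop = nn.Dropout(0.2)

    def forward(self, x):                        # x: (B, seq_len)
        z = self.blocks(x.unsqueeze(1))          # (B, 128, seq_len // 8)
        z = self.pool(z).squeeze(-1)             # (B, 128)
        return self.drop(self.norm(self.proj(z)))

net = StructuredBranch()
emb = net(torch.randn(8, 30))                    # (8, 256) structured embedding
```

Treating the 30 encoded values as a length-30 sequence with one input channel lets the stacked kernels capture interactions between neighboring features at growing receptive fields.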
Because recurrence cases form a small minority class, a conditional tabular
Generative Adversarial Network (GAN) is trained within each training fold for the
structured branch to perform data augmentation and balance the class distribution. We
adopt a conditional WGAN-GP variant adapted to tabular data, where the generator is
conditioned on class labels and encoded categorical features to generate synthetic
positive samples consistent with the real joint distribution. The discriminator is
trained with a Lipschitz constraint and gradient penalty to improve stability. To prevent
the augmented data from introducing distribution drift and unreasonable feature
combinations, we apply a triple screening process after generation: first, density
filtering based on Mahalanobis distance to remove low-density outliers; second, hard
constraints based on clinical rules (physiological relationships between indicators and
consistency of scoring calculations); and third, an envelope screening of the positive
class manifold using a one-class SVM fitted only within the training fold. By performing
fold-wise augmentation within the training set, we achieve stratified sampling alignment,
while maintaining the natural distribution of the validation and test sets to avoid
evaluation bias.
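The core of the conditional WGAN-GP critic update is the gradient penalty on interpolates between real and synthetic rows; the sketch below shows that term under assumed dimensions (28 features, one-hot class condition, penalty weight 10), with the critic architecture chosen for illustration only:

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake, labels, lam=10.0):
    """WGAN-GP penalty: push the critic's gradient norm toward 1
    on random interpolates between real and synthetic rows."""
    eps = torch.rand(real.size(0), 1)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    score = critic(x_hat, labels)
    grad, = torch.autograd.grad(score.sum(), x_hat, create_graph=True)
    return lam * ((grad.norm(2, dim=1) - 1) ** 2).mean()

class Critic(nn.Module):
    """Conditional critic: tabular row concatenated with its class label."""
    def __init__(self, n_feat=28, n_cls=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_feat + n_cls, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))

    def forward(self, x, y):
        return self.net(torch.cat([x, y], dim=1))

critic = Critic()
real = torch.randn(16, 28)
fake = torch.randn(16, 28)                      # generator output stand-in
y = torch.eye(2)[torch.randint(0, 2, (16,))]    # one-hot class condition
gp = gradient_penalty(critic, real, fake, y)    # added to the critic loss
```

In training, this penalty is added to the Wasserstein critic loss each step; the Mahalanobis, clinical-rule, and one-class-SVM screens described above are applied only after generation, not inside this loop.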
The fusion stage follows a late fusion strategy with attention weighting. In each model
instance, the 256-dimensional embedding from the structured branch is concatenated with
the corresponding 256-dimensional text embedding from the LLM, forming a 512-dimensional
joint representation. This representation is then fed into a multi-head scaled
dot-product attention module for learnable cross-modal weighting, with the number of
heads set to 4 and the key/value dimension set to 64. A gating mechanism is incorporated
to suppress noisy text segments or weakly relevant structured components. The attention
output is mapped to the final binary classification logit via a two-layer feedforward
network (hidden dimension 256, activation GELU, dropout=0.2), and the cross-entropy loss
with label-balanced weights is used as the loss function. Optimization is performed using
AdamW (learning rate 2e-4, weight decay 0.01), with cosine annealing and warmup. Training
employs stratified k-fold cross-validation (k=5), with patient-level splitting to prevent
sample leakage. Within each fold, a validation set is used for early stopping and
hyperparameter selection. Evaluation metrics include accuracy, F1-score, sensitivity,
specificity, and area under the receiver operating characteristic curve (AUC-ROC), and
95% confidence intervals are reported. To ensure comparability across the four LLMs, all
non-text side components, optimizer settings, training epochs, and early stopping
criteria are kept consistent, with only the encoder being replaced and fine-tuned
individually on the text side. During inference, deterministic forward passes
(temperature 0, no sampling) are used to obtain stable embeddings, and the maximum text length and
truncation strategy are fixed to avoid bias caused by differences in context length
between models.
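The fusion head can be sketched as follows. Note one assumption: PyTorch's MultiheadAttention fixes the per-head dimension at embed_dim // num_heads (here 512 / 4 = 128), so the stated key/value dimension of 64 would require a custom projection; the sigmoid gate and layer sizes otherwise follow the text:

```python
import torch
import torch.nn as nn

class FusionHead(nn.Module):
    """Late fusion: concatenate 256-d structured and 256-d text embeddings,
    apply 4-head attention with a gating mechanism, then a two-layer
    feedforward classifier producing the binary logit."""
    def __init__(self, d=512, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())
        self.ffn = nn.Sequential(
            nn.Linear(d, 256), nn.GELU(), nn.Dropout(0.2), nn.Linear(256, 1))

    def forward(self, struct_emb, text_emb):
        z = torch.cat([struct_emb, text_emb], dim=-1).unsqueeze(1)  # (B, 1, 512)
        a, _ = self.attn(z, z, z)          # learnable cross-modal weighting
        a = self.gate(a) * a               # suppress weakly relevant components
        return self.ffn(a.squeeze(1))      # (B, 1) binary logit

head = FusionHead()
logit = head(torch.randn(8, 256), torch.randn(8, 256))
```

The label-balanced cross-entropy described above could then be realized, for example, via `nn.BCEWithLogitsLoss(pos_weight=...)` on this logit, with AdamW (lr 2e-4, weight decay 0.01) as the optimizer.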
For explainability analysis, we calculate SHAP values on the log-odds domain of the fused
model output to achieve global and individual explanations. For the structured branch, we
use DeepSHAP to approximate the marginal contribution of the convolutional pathway to the
output, reporting the global importance and interaction effects of each original clinical
feature. For the text branch, we combine attention-based token importance with SHAP's
text masking estimation to locate the descriptive words driving the prediction. Based on
the probability output of the best-performing model, we determine the optimal threshold
using Youden's J statistic to classify patients into high-risk and low-risk groups, followed by
survival analysis to compare outcome differences between the two groups.
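The Youden's J cutoff selection amounts to maximizing sensitivity + specificity - 1 over the ROC curve, as in this short sketch (the toy labels and probabilities are illustrative only):

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_threshold(y_true, y_prob):
    """Cutoff maximizing J = sensitivity + specificity - 1 = TPR - FPR."""
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    return thresholds[np.argmax(tpr - fpr)]

y_true = np.array([0, 0, 0, 1, 1, 1, 0, 1])
y_prob = np.array([0.1, 0.3, 0.35, 0.6, 0.8, 0.9, 0.2, 0.7])

t = youden_threshold(y_true, y_prob)
high_risk = y_prob >= t    # dichotomize patients into high- vs low-risk groups
```

The resulting high-risk/low-risk grouping is then carried into the survival analysis comparing outcomes between the two groups.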