Upload folder using huggingface_hub
Browse files- README.md +51 -0
- config.json +7 -0
- modeling_brainiac.py +51 -0
- pytorch_model.bin +3 -0
README.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BrainIAC Model
|
2 |
+
|
3 |
+
This is the official implementation of the BrainIAC model, a 3D ResNet50-based architecture designed for brain image analysis.
|
4 |
+
|
5 |
+
## Model Description
|
6 |
+
|
7 |
+
BrainIAC is built on a modified ResNet50 architecture that processes 3D brain imaging data. The model has been adapted to handle volumetric inputs through 3D convolutions and produces feature vectors that capture relevant brain imaging characteristics.
|
8 |
+
|
9 |
+
## Model Architecture
|
10 |
+
- Base Architecture: ResNet50 (modified for 3D)
|
11 |
+
- Input: 3D brain volumes [batch_size, 1, D, H, W]
|
12 |
+
- Output: Feature vector of dimension 2048
|
13 |
+
- First layer: 3D convolution (1 channel input)
|
14 |
+
- Final layer: Identity (returns features directly)
|
15 |
+
|
16 |
+
## Usage
|
17 |
+
|
18 |
+
```python
|
19 |
+
from transformers import AutoModel
|
20 |
+
import torch
|
21 |
+
|
22 |
+
# Load model
|
23 |
+
model = AutoModel.from_pretrained("your-username/brainiac")
|
24 |
+
model.eval()
|
25 |
+
|
26 |
+
# Prepare your input tensor
|
27 |
+
# Adjust D, H, W according to your requirements
|
28 |
+
batch_size = 1
|
29 |
+
D, H, W = 16, 224, 224 # Example dimensions
|
30 |
+
input_tensor = torch.randn(batch_size, 1, D, H, W)
|
31 |
+
|
32 |
+
# Get features
|
33 |
+
with torch.no_grad():
|
34 |
+
features = model(input_tensor)
|
35 |
+
|
36 |
+
print(f"Output feature shape: {features.shape}") # Should be [batch_size, 2048]
|
37 |
+
```
|
38 |
+
|
39 |
+
## Requirements
|
40 |
+
```
|
41 |
+
torch>=1.9.0
|
42 |
+
monai
|
43 |
+
transformers
|
44 |
+
```
|
45 |
+
|
46 |
+
## Citation
|
47 |
+
If you use this model in your research, please cite:
|
48 |
+
[Add your citation information]
|
49 |
+
|
50 |
+
## License
|
51 |
+
[Add your license information]
|
config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_attn_implementation_autoset": true,
|
3 |
+
"in_channels": 1,
|
4 |
+
"model_type": "brainiac",
|
5 |
+
"out_features": 2048,
|
6 |
+
"transformers_version": "4.46.3"
|
7 |
+
}
|
modeling_brainiac.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import Optional
|
4 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
5 |
+
from monai.networks.nets import resnet50
|
6 |
+
|
7 |
+
class BrainiacConfig(PretrainedConfig):
|
8 |
+
model_type = "brainiac"
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
in_channels: int = 1,
|
13 |
+
out_features: int = 2048, # ResNet50's default feature dimension
|
14 |
+
**kwargs
|
15 |
+
):
|
16 |
+
super().__init__(**kwargs)
|
17 |
+
self.in_channels = in_channels
|
18 |
+
self.out_features = out_features
|
19 |
+
|
20 |
+
class BrainiacModel(PreTrainedModel):
|
21 |
+
config_class = BrainiacConfig
|
22 |
+
base_model_prefix = "brainiac"
|
23 |
+
|
24 |
+
def __init__(self, config: BrainiacConfig):
|
25 |
+
super().__init__(config)
|
26 |
+
self.config = config
|
27 |
+
|
28 |
+
# Initialize ResNet50 from MONAI
|
29 |
+
self.resnet = resnet50(pretrained=False)
|
30 |
+
# Modify first conv layer for 3D input
|
31 |
+
self.resnet.conv1 = nn.Conv3d(
|
32 |
+
config.in_channels,
|
33 |
+
64,
|
34 |
+
kernel_size=7,
|
35 |
+
stride=2,
|
36 |
+
padding=3,
|
37 |
+
bias=False
|
38 |
+
)
|
39 |
+
# Replace final FC layer with Identity
|
40 |
+
self.resnet.fc = nn.Identity()
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
43 |
+
return self.resnet(x)
|
44 |
+
|
45 |
+
def _init_weights(self, module):
|
46 |
+
"""Initialize the weights"""
|
47 |
+
if isinstance(module, nn.Conv3d):
|
48 |
+
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
|
49 |
+
elif isinstance(module, (nn.BatchNorm3d, nn.GroupNorm)):
|
50 |
+
nn.init.constant_(module.weight, 1)
|
51 |
+
nn.init.constant_(module.bias, 0)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f0ca899a9d4c40cec82a1c00950d82978fa409a57ec7088d80f86399ca1121d2
|
3 |
+
size 184954453
|