Spaces:
Runtime error
Runtime error
NimaBoscarino
commited on
Commit
·
c3ede35
1
Parent(s):
25eadae
Large rewrite, simplification, new UI
Browse files- .gitignore +3 -0
- README.md +1 -1
- app.py +83 -17
- fonts/DidactGothic-Regular.ttf +0 -0
- fonts/Inter-Regular.ttf +0 -0
- requirements.txt +1 -1
- substra_launcher.py +8 -4
- substra_template/Dockerfile +2 -30
- substra_template/README.md +10 -0
- substra_template/__init__.py +0 -0
- substra_template/mlflow-2.1.2.dev0-py3-none-any.whl +0 -3
- substra_template/mlflow_live_performances.py +0 -45
- substra_template/requirements.txt +0 -13
- substra_template/run.sh +0 -13
- substra_template/run_compute_plan.py +0 -40
- substra_template/substra_helpers/__init__.py +0 -0
- substra_template/substra_helpers/dataset.py +0 -29
- substra_template/substra_helpers/dataset_assets/description.md +0 -18
- substra_template/substra_helpers/dataset_assets/opener.py +0 -20
- substra_template/substra_helpers/model.py +0 -25
- substra_template/substra_helpers/substra_runner.py +0 -194
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.idea/
|
2 |
+
.DS_Store
|
3 |
+
__pycache__/
|
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
|
|
4 |
colorFrom: purple
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: gpl-3.0
|
|
|
4 |
colorFrom: purple
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.24.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: gpl-3.0
|
app.py
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
theme = gr.themes.Default(primary_hue="blue").set(
|
@@ -7,9 +14,49 @@ theme = gr.themes.Default(primary_hue="blue").set(
|
|
7 |
)
|
8 |
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
demo = gr.Blocks(theme=theme, css="""\
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
.gradio-container {
|
12 |
-
width: 100
|
13 |
}
|
14 |
|
15 |
.margin-top {
|
@@ -26,19 +73,24 @@ demo = gr.Blocks(theme=theme, css="""\
|
|
26 |
}
|
27 |
|
28 |
.blue {
|
29 |
-
/**
|
30 |
background-image: url("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/substra-banner.png");
|
31 |
background-size: cover;
|
32 |
-
**/
|
33 |
-
background-color: #223fb3;
|
34 |
}
|
35 |
|
36 |
.blue p {
|
37 |
color: white !important;
|
38 |
}
|
39 |
|
|
|
|
|
|
|
|
|
40 |
.info-box {
|
41 |
background: transparent !important;
|
|
|
|
|
|
|
|
|
42 |
}
|
43 |
""")
|
44 |
|
@@ -49,7 +101,7 @@ with demo:
|
|
49 |
gr.Markdown("# Federated Learning with Substra")
|
50 |
with gr.Row():
|
51 |
with gr.Column(scale=1, elem_classes=["blue", "column"]):
|
52 |
-
gr.Markdown("Here you can run a quick simulation of Federated Learning
|
53 |
gr.Markdown("Check out the accompanying blog post to learn more.")
|
54 |
with gr.Box(elem_classes=["info-box"]):
|
55 |
gr.Markdown("""\
|
@@ -60,22 +112,23 @@ with demo:
|
|
60 |
with gr.Column(scale=3, elem_classes=["white", "column"]):
|
61 |
gr.Markdown("""\
|
62 |
Data scientists doing medical research often face a shortage of high quality and diverse data to \
|
63 |
-
effectively train models. This challenge can be overcome by securely allowing training on
|
64 |
-
data through
|
65 |
-
enables researchers to easily train ML models on remote data regardless of the
|
66 |
-
using or the data
|
67 |
""")
|
68 |
-
gr.Markdown("### Here we show an example of image data located in two different hospitals
|
69 |
gr.Markdown("""\
|
70 |
-
By playing with the distribution of data in the
|
71 |
the federated models compare with models trained on single datasets. The data used is from the \
|
72 |
-
Camelyon17 dataset, a commonly used benchmark in the medical world that comes from
|
73 |
-
The sample below shows normal cells on the
|
|
|
74 |
""")
|
75 |
gr.HTML("""
|
76 |
<img
|
77 |
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/substra-tumor.png"
|
78 |
-
style="
|
79 |
/>
|
80 |
""")
|
81 |
gr.Markdown("""\
|
@@ -87,8 +140,21 @@ with demo:
|
|
87 |
""")
|
88 |
|
89 |
with gr.Row(elem_classes=["margin-top"]):
|
90 |
-
gr.Slider(
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import uuid
|
3 |
+
import asyncio
|
4 |
+
|
5 |
+
from substra_launcher import launch_substra_space
|
6 |
+
from huggingface_hub import HfApi
|
7 |
+
|
8 |
+
hf_api = HfApi()
|
9 |
|
10 |
|
11 |
theme = gr.themes.Default(primary_hue="blue").set(
|
|
|
14 |
)
|
15 |
|
16 |
|
17 |
+
async def launch_experiment(hospital_a, hospital_b):
|
18 |
+
experiment_id = str(uuid.uuid4())
|
19 |
+
|
20 |
+
asyncio.create_task(launch_substra_space(
|
21 |
+
hf_api=hf_api,
|
22 |
+
repo_id=experiment_id,
|
23 |
+
hospital_a=hospital_a,
|
24 |
+
hospital_b=hospital_b,
|
25 |
+
))
|
26 |
+
|
27 |
+
url = f"https://hf.space/NimaBoscarino/{experiment_id}"
|
28 |
+
|
29 |
+
return (
|
30 |
+
gr.Button.update(interactive=False),
|
31 |
+
gr.Markdown.update(
|
32 |
+
visible=True,
|
33 |
+
value=f"Your experiment is available at [hf.space/NimaBoscarino/{experiment_id}]({url})!"
|
34 |
+
)
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
demo = gr.Blocks(theme=theme, css="""\
|
39 |
+
@font-face {
|
40 |
+
font-family: "Didact Gothic";
|
41 |
+
src: url('https://huggingface.co/datasets/NimaBoscarino/assets/resolve/main/substra/DidactGothic-Regular.ttf') format('truetype');
|
42 |
+
}
|
43 |
+
|
44 |
+
@font-face {
|
45 |
+
font-family: "Inter";
|
46 |
+
src: url('https://huggingface.co/datasets/NimaBoscarino/assets/resolve/main/substra/Inter-Regular.ttf') format('truetype');
|
47 |
+
}
|
48 |
+
|
49 |
+
h1 {
|
50 |
+
font-family: "Didact Gothic";
|
51 |
+
font-size: 40px !important;
|
52 |
+
}
|
53 |
+
|
54 |
+
p {
|
55 |
+
font-family: "Inter";
|
56 |
+
}
|
57 |
+
|
58 |
.gradio-container {
|
59 |
+
min-width: 100% !important;
|
60 |
}
|
61 |
|
62 |
.margin-top {
|
|
|
73 |
}
|
74 |
|
75 |
.blue {
|
|
|
76 |
background-image: url("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/substra-banner.png");
|
77 |
background-size: cover;
|
|
|
|
|
78 |
}
|
79 |
|
80 |
.blue p {
|
81 |
color: white !important;
|
82 |
}
|
83 |
|
84 |
+
.blue strong {
|
85 |
+
color: white !important;
|
86 |
+
}
|
87 |
+
|
88 |
.info-box {
|
89 |
background: transparent !important;
|
90 |
+
border-radius: 20px !important;
|
91 |
+
border-color: white !important;
|
92 |
+
border-width: 4px !important;
|
93 |
+
padding: 20px !important;
|
94 |
}
|
95 |
""")
|
96 |
|
|
|
101 |
gr.Markdown("# Federated Learning with Substra")
|
102 |
with gr.Row():
|
103 |
with gr.Column(scale=1, elem_classes=["blue", "column"]):
|
104 |
+
gr.Markdown("Here you can run a **quick simulation of Federated Learning**.")
|
105 |
gr.Markdown("Check out the accompanying blog post to learn more.")
|
106 |
with gr.Box(elem_classes=["info-box"]):
|
107 |
gr.Markdown("""\
|
|
|
112 |
with gr.Column(scale=3, elem_classes=["white", "column"]):
|
113 |
gr.Markdown("""\
|
114 |
Data scientists doing medical research often face a shortage of high quality and diverse data to \
|
115 |
+
effectively train models. This challenge can be overcome by securely allowing training on protected \
|
116 |
+
data through Federated Learning. [Substra](https://docs.substra.org/) is a Python based Federated \
|
117 |
+
Learning software that enables researchers to easily train ML models on remote data regardless of the \
|
118 |
+
ML library they are using or the data type they are working with.
|
119 |
""")
|
120 |
+
gr.Markdown("### Here we show an example of image data located in **two different hospitals**.")
|
121 |
gr.Markdown("""\
|
122 |
+
By playing with the distribution of data in the two simulated hospitals, you'll be able to compare how \
|
123 |
the federated models compare with models trained on single datasets. The data used is from the \
|
124 |
+
Camelyon17 dataset, a commonly used benchmark in the medical world that comes from \
|
125 |
+
[this challenge](https://camelyon17.grand-challenge.org/). The sample below shows normal cells on the \
|
126 |
+
left compared with cancer cells on the right.
|
127 |
""")
|
128 |
gr.HTML("""
|
129 |
<img
|
130 |
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/substra-tumor.png"
|
131 |
+
style="height: 300px; margin: auto;"
|
132 |
/>
|
133 |
""")
|
134 |
gr.Markdown("""\
|
|
|
140 |
""")
|
141 |
|
142 |
with gr.Row(elem_classes=["margin-top"]):
|
143 |
+
hospital_a_slider = gr.Slider(
|
144 |
+
label="Percentage of positive samples in Hospital A",
|
145 |
+
value=50,
|
146 |
+
)
|
147 |
+
hospital_b_slider = gr.Slider(
|
148 |
+
label="Percentage of positive samples in Hospital B",
|
149 |
+
value=50,
|
150 |
+
)
|
151 |
+
launch_experiment_button = gr.Button(value="Launch Experiment 🚀")
|
152 |
+
visit_experiment_text = gr.Markdown(visible=False)
|
153 |
+
|
154 |
+
launch_experiment_button.click(
|
155 |
+
fn=launch_experiment,
|
156 |
+
inputs=[hospital_a_slider, hospital_b_slider],
|
157 |
+
outputs=[launch_experiment_button, visit_experiment_text]
|
158 |
+
)
|
159 |
|
160 |
demo.launch()
|
fonts/DidactGothic-Regular.ttf
ADDED
Binary file (181 kB). View file
|
|
fonts/Inter-Regular.ttf
ADDED
Binary file (748 kB). View file
|
|
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
gradio
|
2 |
pytest
|
3 |
huggingface_hub
|
|
|
1 |
+
gradio
|
2 |
pytest
|
3 |
huggingface_hub
|
substra_launcher.py
CHANGED
@@ -1,7 +1,10 @@
|
|
1 |
from huggingface_hub import HfApi, RepoUrl
|
2 |
|
3 |
|
4 |
-
def launch_substra_space(
|
|
|
|
|
|
|
5 |
repo_id = "NimaBoscarino/" + repo_id
|
6 |
|
7 |
repo_url = hf_api.create_repo(
|
@@ -13,12 +16,13 @@ def launch_substra_space(hf_api: HfApi, num_hospitals: int, repo_id: str) -> Rep
|
|
13 |
hf_api.upload_folder(
|
14 |
repo_id=repo_id,
|
15 |
repo_type="space",
|
16 |
-
folder_path="substra_template/"
|
17 |
)
|
18 |
|
19 |
ENV_FILE = f"""\
|
20 |
-
|
21 |
-
|
|
|
22 |
|
23 |
hf_api.upload_file(
|
24 |
repo_id=repo_id,
|
|
|
1 |
from huggingface_hub import HfApi, RepoUrl
|
2 |
|
3 |
|
4 |
+
async def launch_substra_space(
|
5 |
+
hf_api: HfApi, repo_id: str,
|
6 |
+
hospital_a: int, hospital_b: int,
|
7 |
+
) -> RepoUrl:
|
8 |
repo_id = "NimaBoscarino/" + repo_id
|
9 |
|
10 |
repo_url = hf_api.create_repo(
|
|
|
16 |
hf_api.upload_folder(
|
17 |
repo_id=repo_id,
|
18 |
repo_type="space",
|
19 |
+
folder_path="./substra_template/"
|
20 |
)
|
21 |
|
22 |
ENV_FILE = f"""\
|
23 |
+
SUBSTRA_ORG1_DISTR={hospital_a / 100}
|
24 |
+
SUBSTRA_ORG2_DISTR={hospital_b / 100}\
|
25 |
+
"""
|
26 |
|
27 |
hf_api.upload_file(
|
28 |
repo_id=repo_id,
|
substra_template/Dockerfile
CHANGED
@@ -1,31 +1,3 @@
|
|
1 |
-
FROM
|
2 |
|
3 |
-
|
4 |
-
WORKDIR /code
|
5 |
-
|
6 |
-
# Copy the current directory contents into the container at /code
|
7 |
-
COPY ./requirements.txt /code/requirements.txt
|
8 |
-
COPY ./mlflow-2.1.2.dev0-py3-none-any.whl /code/mlflow-2.1.2.dev0-py3-none-any.whl
|
9 |
-
|
10 |
-
# Install requirements.txt
|
11 |
-
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
12 |
-
RUN chmod -R 777 /usr/local/lib/python3.10/site-packages/
|
13 |
-
|
14 |
-
# Set up a new user named "user" with user ID 1000
|
15 |
-
RUN useradd -m -u 1000 user
|
16 |
-
# Switch to the "user" user
|
17 |
-
USER user
|
18 |
-
# Set home to the user's home directory
|
19 |
-
ENV HOME=/home/user \
|
20 |
-
PATH=/home/user/.local/bin:$PATH
|
21 |
-
|
22 |
-
# Set the working directory to the user's home directory
|
23 |
-
WORKDIR $HOME/app
|
24 |
-
|
25 |
-
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
26 |
-
COPY --chown=user . $HOME/app
|
27 |
-
|
28 |
-
RUN chmod -R 777 $HOME/app/
|
29 |
-
|
30 |
-
EXPOSE 7860
|
31 |
-
CMD ["bash", "run.sh"]
|
|
|
1 |
+
FROM nimaboscarino/substra-trainer:latest
|
2 |
|
3 |
+
CMD ["bash", "docker-run.sh"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/README.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Substra Trainer
|
3 |
+
emoji: 🚀
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: gray
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
---
|
9 |
+
|
10 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
substra_template/__init__.py
DELETED
File without changes
|
substra_template/mlflow-2.1.2.dev0-py3-none-any.whl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:e1f15359f38fab62f43a7a3d51f56c86c882a4cb1c3dcabeda6daf5dc47f1613
|
3 |
-
size 17638174
|
|
|
|
|
|
|
|
substra_template/mlflow_live_performances.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
import pandas as pd
|
2 |
-
import json
|
3 |
-
from pathlib import Path
|
4 |
-
from mlflow import log_metric
|
5 |
-
import time
|
6 |
-
import os
|
7 |
-
from glob import glob
|
8 |
-
|
9 |
-
TIMEOUT = 240 # Number of seconds to stop the script after the last update of the json file
|
10 |
-
POLLING_FREQUENCY = 10 # Try to read the updates in the file every 10 seconds
|
11 |
-
|
12 |
-
# Wait for the file to be found
|
13 |
-
start = time.time()
|
14 |
-
while not len(glob(str(Path("local-worker") / "live_performances" / "*" / "performances.json"))) > 0:
|
15 |
-
time.sleep(POLLING_FREQUENCY)
|
16 |
-
if time.time() - start >= TIMEOUT:
|
17 |
-
raise TimeoutError("The performance file does not exist, maybe no test task has been executed yet.")
|
18 |
-
|
19 |
-
path_to_json = Path(glob(str(Path("local-worker") / "live_performances" / "*" / "performances.json"))[0])
|
20 |
-
|
21 |
-
logged_rows = []
|
22 |
-
last_update = time.time()
|
23 |
-
|
24 |
-
while (time.time() - last_update) <= TIMEOUT:
|
25 |
-
|
26 |
-
if last_update == os.path.getmtime(str(path_to_json)):
|
27 |
-
time.sleep(POLLING_FREQUENCY)
|
28 |
-
continue
|
29 |
-
|
30 |
-
last_update = os.path.getmtime(str(path_to_json))
|
31 |
-
|
32 |
-
time.sleep(1) # Waiting for the json to be fully written
|
33 |
-
dict_perf = json.load(path_to_json.open())
|
34 |
-
|
35 |
-
df = pd.DataFrame(dict_perf)
|
36 |
-
|
37 |
-
for _, row in df.iterrows():
|
38 |
-
if row["testtask_key"] in logged_rows:
|
39 |
-
continue
|
40 |
-
|
41 |
-
logged_rows.append(row["testtask_key"])
|
42 |
-
|
43 |
-
step = int(row["round_idx"]) if row["round_idx"] is not None else int(row["testtask_rank"])
|
44 |
-
|
45 |
-
log_metric(f"{row['metric_name']}_{row['worker']}", row["performance"], step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/requirements.txt
DELETED
@@ -1,13 +0,0 @@
|
|
1 |
-
gradio
|
2 |
-
substrafl
|
3 |
-
datasets
|
4 |
-
torch
|
5 |
-
torchvision
|
6 |
-
scikit-learn
|
7 |
-
numpy==1.23.0
|
8 |
-
Pillow
|
9 |
-
transformers
|
10 |
-
matplotlib
|
11 |
-
pandas
|
12 |
-
python-dotenv
|
13 |
-
./mlflow-2.1.2.dev0-py3-none-any.whl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/run.sh
DELETED
@@ -1,13 +0,0 @@
|
|
1 |
-
PYTHONPATH=$HOME/app python run_compute_plan.py &
|
2 |
-
PYTHONPATH=$HOME/app python mlflow_live_performances.py &
|
3 |
-
|
4 |
-
SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
|
5 |
-
|
6 |
-
# Fix for the UI code being embedded in an iframe
|
7 |
-
# Replace window.parent.location.origin with *
|
8 |
-
for i in $SITE_PACKAGES/mlflow/server/js/build/static/js/*.js; do
|
9 |
-
sed -i 's/window\.parent\.location\.origin)/"*")/' $i
|
10 |
-
sed 's/window.top?.location.href || window.location.href/window.location.href/g' -i $i
|
11 |
-
done
|
12 |
-
|
13 |
-
mlflow ui --port 7860 --host 0.0.0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/run_compute_plan.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
from substra_helpers.substra_runner import SubstraRunner, algo_generator
|
2 |
-
from substra_helpers.model import CNN
|
3 |
-
from substra_helpers.dataset import TorchDataset
|
4 |
-
from substrafl.strategies import FedAvg
|
5 |
-
|
6 |
-
import torch
|
7 |
-
|
8 |
-
from dotenv import load_dotenv
|
9 |
-
import os
|
10 |
-
load_dotenv()
|
11 |
-
|
12 |
-
NUM_CLIENTS = int(os.environ["SUBSTRA_NUM_HOSPITALS"])
|
13 |
-
|
14 |
-
seed = 42
|
15 |
-
torch.manual_seed(seed)
|
16 |
-
model = CNN()
|
17 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
18 |
-
criterion = torch.nn.CrossEntropyLoss()
|
19 |
-
|
20 |
-
runner = SubstraRunner(num_clients=NUM_CLIENTS)
|
21 |
-
runner.set_up_clients()
|
22 |
-
runner.prepare_data()
|
23 |
-
runner.register_data()
|
24 |
-
runner.register_metric()
|
25 |
-
|
26 |
-
runner.algorithm = algo_generator(
|
27 |
-
model=model,
|
28 |
-
criterion=criterion,
|
29 |
-
optimizer=optimizer,
|
30 |
-
index_generator=runner.index_generator,
|
31 |
-
dataset=TorchDataset,
|
32 |
-
seed=seed
|
33 |
-
)()
|
34 |
-
|
35 |
-
runner.strategy = FedAvg()
|
36 |
-
|
37 |
-
runner.set_aggregation()
|
38 |
-
runner.set_testing()
|
39 |
-
|
40 |
-
runner.run_compute_plan()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/substra_helpers/__init__.py
DELETED
File without changes
|
substra_template/substra_helpers/dataset.py
DELETED
@@ -1,29 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch.utils import data
|
3 |
-
import torch.nn.functional as F
|
4 |
-
import numpy as np
|
5 |
-
|
6 |
-
|
7 |
-
class TorchDataset(data.Dataset):
|
8 |
-
def __init__(self, datasamples, is_inference: bool):
|
9 |
-
self.x = datasamples["image"]
|
10 |
-
self.y = datasamples["label"]
|
11 |
-
self.is_inference = is_inference
|
12 |
-
|
13 |
-
def __getitem__(self, idx):
|
14 |
-
|
15 |
-
if self.is_inference:
|
16 |
-
x = torch.FloatTensor(np.array(self.x[idx])[None, ...]) / 255
|
17 |
-
return x
|
18 |
-
|
19 |
-
else:
|
20 |
-
x = torch.FloatTensor(np.array(self.x[idx])[None, ...]) / 255
|
21 |
-
|
22 |
-
y = torch.tensor(self.y[idx]).type(torch.int64)
|
23 |
-
y = F.one_hot(y, 10)
|
24 |
-
y = y.type(torch.float32)
|
25 |
-
|
26 |
-
return x, y
|
27 |
-
|
28 |
-
def __len__(self):
|
29 |
-
return len(self.x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/substra_helpers/dataset_assets/description.md
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
# Mnist
|
2 |
-
|
3 |
-
This dataset is [THE MNIST DATABASE of handwritten digits](http://yann.lecun.com/exdb/mnist/).
|
4 |
-
|
5 |
-
The target is the number (0 -> 9) represented by the pixels.
|
6 |
-
|
7 |
-
## Data repartition
|
8 |
-
|
9 |
-
### Train and test
|
10 |
-
|
11 |
-
### Split data between organizations
|
12 |
-
|
13 |
-
## Opener usage
|
14 |
-
|
15 |
-
The opener exposes 2 methods:
|
16 |
-
|
17 |
-
- `get_data` returns a dictionary containing the images and the labels as numpy arrays
|
18 |
-
- `fake_data` returns a fake data sample of images and labels in a dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/substra_helpers/dataset_assets/opener.py
DELETED
@@ -1,20 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import substratools as tools
|
3 |
-
from datasets import load_from_disk
|
4 |
-
from transformers import ImageFeatureExtractionMixin
|
5 |
-
|
6 |
-
|
7 |
-
class MnistOpener(tools.Opener):
|
8 |
-
def fake_data(self, n_samples=None):
|
9 |
-
N_SAMPLES = n_samples if n_samples and n_samples <= 100 else 100
|
10 |
-
|
11 |
-
fake_images = np.random.randint(256, size=(N_SAMPLES, 28, 28))
|
12 |
-
|
13 |
-
fake_labels = np.random.randint(10, size=N_SAMPLES)
|
14 |
-
|
15 |
-
data = {"image": fake_images, "label": fake_labels}
|
16 |
-
|
17 |
-
return data
|
18 |
-
|
19 |
-
def get_data(self, folders):
|
20 |
-
return load_from_disk(folders[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/substra_helpers/model.py
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
from torch import nn
|
2 |
-
import torch.nn.functional as F
|
3 |
-
|
4 |
-
|
5 |
-
# TODO: Would be cool to use a simple Transformer model... then I could use the Trainer API 👀
|
6 |
-
class CNN(nn.Module):
|
7 |
-
def __init__(self):
|
8 |
-
super(CNN, self).__init__()
|
9 |
-
self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
|
10 |
-
self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
|
11 |
-
self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
|
12 |
-
self.fc1 = nn.Linear(3 * 3 * 64, 256)
|
13 |
-
self.fc2 = nn.Linear(256, 10)
|
14 |
-
|
15 |
-
def forward(self, x, eval=False):
|
16 |
-
x = F.relu(self.conv1(x))
|
17 |
-
x = F.relu(F.max_pool2d(self.conv2(x), 2))
|
18 |
-
x = F.dropout(x, p=0.5, training=not eval)
|
19 |
-
x = F.relu(F.max_pool2d(self.conv3(x), 2))
|
20 |
-
x = F.dropout(x, p=0.5, training=not eval)
|
21 |
-
x = x.view(-1, 3 * 3 * 64)
|
22 |
-
x = F.relu(self.fc1(x))
|
23 |
-
x = F.dropout(x, p=0.5, training=not eval)
|
24 |
-
x = self.fc2(x)
|
25 |
-
return F.log_softmax(x, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
substra_template/substra_helpers/substra_runner.py
DELETED
@@ -1,194 +0,0 @@
|
|
1 |
-
import pathlib
|
2 |
-
import shutil
|
3 |
-
from typing import Optional, List
|
4 |
-
|
5 |
-
from substra import Client, BackendType
|
6 |
-
|
7 |
-
from substra.sdk.schemas import (
|
8 |
-
DatasetSpec,
|
9 |
-
Permissions,
|
10 |
-
DataSampleSpec
|
11 |
-
)
|
12 |
-
|
13 |
-
from substrafl.strategies import Strategy
|
14 |
-
from substrafl.dependency import Dependency
|
15 |
-
from substrafl.remote.register import add_metric
|
16 |
-
from substrafl.index_generator import NpIndexGenerator
|
17 |
-
from substrafl.algorithms.pytorch import TorchFedAvgAlgo
|
18 |
-
|
19 |
-
from substrafl.nodes import TrainDataNode, AggregationNode, TestDataNode
|
20 |
-
from substrafl.evaluation_strategy import EvaluationStrategy
|
21 |
-
|
22 |
-
from substrafl.experiment import execute_experiment
|
23 |
-
from substra.sdk.models import ComputePlan
|
24 |
-
|
25 |
-
from datasets import load_dataset, Dataset
|
26 |
-
from sklearn.metrics import accuracy_score
|
27 |
-
import numpy as np
|
28 |
-
|
29 |
-
import torch
|
30 |
-
|
31 |
-
|
32 |
-
class SubstraRunner:
|
33 |
-
def __init__(self, num_clients: int):
|
34 |
-
self.num_clients = num_clients
|
35 |
-
self.clients = {}
|
36 |
-
self.algo_provider: Optional[Client] = None
|
37 |
-
|
38 |
-
self.datasets: List[Dataset] = []
|
39 |
-
self.test_dataset: Optional[Dataset] = None
|
40 |
-
self.path = pathlib.Path(__file__).parent.resolve()
|
41 |
-
|
42 |
-
self.dataset_keys = {}
|
43 |
-
self.train_data_sample_keys = {}
|
44 |
-
self.test_data_sample_keys = {}
|
45 |
-
|
46 |
-
self.metric_key: Optional[str] = None
|
47 |
-
|
48 |
-
NUM_UPDATES = 100
|
49 |
-
BATCH_SIZE = 32
|
50 |
-
|
51 |
-
self.index_generator = NpIndexGenerator(
|
52 |
-
batch_size=BATCH_SIZE,
|
53 |
-
num_updates=NUM_UPDATES,
|
54 |
-
)
|
55 |
-
|
56 |
-
self.algorithm: Optional[TorchFedAvgAlgo] = None
|
57 |
-
self.strategy: Optional[Strategy] = None
|
58 |
-
|
59 |
-
self.aggregation_node: Optional[AggregationNode] = None
|
60 |
-
self.train_data_nodes = list()
|
61 |
-
self.test_data_nodes = list()
|
62 |
-
self.eval_strategy: Optional[EvaluationStrategy] = None
|
63 |
-
|
64 |
-
self.NUM_ROUNDS = 3
|
65 |
-
self.compute_plan: Optional[ComputePlan] = None
|
66 |
-
|
67 |
-
self.experiment_folder = self.path / "experiment_summaries"
|
68 |
-
|
69 |
-
def set_up_clients(self):
|
70 |
-
self.algo_provider = Client(backend_type=BackendType.LOCAL_SUBPROCESS)
|
71 |
-
|
72 |
-
self.clients = {
|
73 |
-
c.organization_info().organization_id: c
|
74 |
-
for c in [Client(backend_type=BackendType.LOCAL_SUBPROCESS) for _ in range(self.num_clients - 1)]
|
75 |
-
}
|
76 |
-
|
77 |
-
def prepare_data(self):
|
78 |
-
dataset = load_dataset("mnist", split="train").shuffle()
|
79 |
-
self.datasets = [dataset.shard(num_shards=self.num_clients - 1, index=i) for i in range(self.num_clients - 1)]
|
80 |
-
|
81 |
-
self.test_dataset = load_dataset("mnist", split="test")
|
82 |
-
|
83 |
-
data_path = self.path / "data"
|
84 |
-
if data_path.exists() and data_path.is_dir():
|
85 |
-
shutil.rmtree(data_path)
|
86 |
-
|
87 |
-
for i, client_id in enumerate(self.clients):
|
88 |
-
ds = self.datasets[i]
|
89 |
-
ds.save_to_disk(data_path / client_id / "train")
|
90 |
-
self.test_dataset.save_to_disk(data_path / client_id / "test")
|
91 |
-
|
92 |
-
def register_data(self):
|
93 |
-
for client_id, client in self.clients.items():
|
94 |
-
permissions_dataset = Permissions(public=False, authorized_ids=[
|
95 |
-
self.algo_provider.organization_info().organization_id
|
96 |
-
])
|
97 |
-
|
98 |
-
dataset = DatasetSpec(
|
99 |
-
name="MNIST",
|
100 |
-
type="npy",
|
101 |
-
data_opener=self.path / pathlib.Path("dataset_assets/opener.py"),
|
102 |
-
description=self.path / pathlib.Path("dataset_assets/description.md"),
|
103 |
-
permissions=permissions_dataset,
|
104 |
-
logs_permission=permissions_dataset,
|
105 |
-
)
|
106 |
-
self.dataset_keys[client_id] = client.add_dataset(dataset)
|
107 |
-
assert self.dataset_keys[client_id], "Missing dataset key"
|
108 |
-
|
109 |
-
self.train_data_sample_keys[client_id] = client.add_data_sample(DataSampleSpec(
|
110 |
-
data_manager_keys=[self.dataset_keys[client_id]],
|
111 |
-
path=self.path / "data" / client_id / "train",
|
112 |
-
))
|
113 |
-
|
114 |
-
data_sample = DataSampleSpec(
|
115 |
-
data_manager_keys=[self.dataset_keys[client_id]],
|
116 |
-
path=self.path / "data" / client_id / "test",
|
117 |
-
)
|
118 |
-
self.test_data_sample_keys[client_id] = client.add_data_sample(data_sample)
|
119 |
-
|
120 |
-
def register_metric(self):
|
121 |
-
permissions_metric = Permissions(
|
122 |
-
public=False,
|
123 |
-
authorized_ids=[
|
124 |
-
self.algo_provider.organization_info().organization_id
|
125 |
-
] + list(self.clients.keys())
|
126 |
-
)
|
127 |
-
|
128 |
-
metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])
|
129 |
-
|
130 |
-
def accuracy(datasamples, predictions_path):
|
131 |
-
y_true = datasamples["label"]
|
132 |
-
y_pred = np.load(predictions_path)
|
133 |
-
|
134 |
-
return accuracy_score(y_true, np.argmax(y_pred, axis=1))
|
135 |
-
|
136 |
-
self.metric_key = add_metric(
|
137 |
-
client=self.algo_provider,
|
138 |
-
metric_function=accuracy,
|
139 |
-
permissions=permissions_metric,
|
140 |
-
dependencies=metric_deps,
|
141 |
-
)
|
142 |
-
|
143 |
-
def set_aggregation(self):
|
144 |
-
self.aggregation_node = AggregationNode(self.algo_provider.organization_info().organization_id)
|
145 |
-
|
146 |
-
for org_id in self.clients:
|
147 |
-
train_data_node = TrainDataNode(
|
148 |
-
organization_id=org_id,
|
149 |
-
data_manager_key=self.dataset_keys[org_id],
|
150 |
-
data_sample_keys=[self.train_data_sample_keys[org_id]],
|
151 |
-
)
|
152 |
-
self.train_data_nodes.append(train_data_node)
|
153 |
-
|
154 |
-
def set_testing(self):
|
155 |
-
for org_id in self.clients:
|
156 |
-
test_data_node = TestDataNode(
|
157 |
-
organization_id=org_id,
|
158 |
-
data_manager_key=self.dataset_keys[org_id],
|
159 |
-
test_data_sample_keys=[self.test_data_sample_keys[org_id]],
|
160 |
-
metric_keys=[self.metric_key],
|
161 |
-
)
|
162 |
-
self.test_data_nodes.append(test_data_node)
|
163 |
-
|
164 |
-
self.eval_strategy = EvaluationStrategy(test_data_nodes=self.test_data_nodes, rounds=1)
|
165 |
-
|
166 |
-
def run_compute_plan(self):
|
167 |
-
algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])
|
168 |
-
|
169 |
-
self.compute_plan = execute_experiment(
|
170 |
-
client=self.algo_provider,
|
171 |
-
algo=self.algorithm,
|
172 |
-
strategy=self.strategy,
|
173 |
-
train_data_nodes=self.train_data_nodes,
|
174 |
-
evaluation_strategy=self.eval_strategy,
|
175 |
-
aggregation_node=self.aggregation_node,
|
176 |
-
num_rounds=self.NUM_ROUNDS,
|
177 |
-
experiment_folder=self.experiment_folder,
|
178 |
-
dependencies=algo_deps,
|
179 |
-
)
|
180 |
-
|
181 |
-
|
182 |
-
def algo_generator(model, criterion, optimizer, index_generator, dataset, seed):
|
183 |
-
class MyAlgo(TorchFedAvgAlgo):
|
184 |
-
def __init__(self):
|
185 |
-
super().__init__(
|
186 |
-
model=model,
|
187 |
-
criterion=criterion,
|
188 |
-
optimizer=optimizer,
|
189 |
-
index_generator=index_generator,
|
190 |
-
dataset=dataset,
|
191 |
-
seed=seed,
|
192 |
-
)
|
193 |
-
|
194 |
-
return MyAlgo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|