sachin commited on
Commit
31e368b
·
1 Parent(s): 24d96ab
Files changed (6) hide show
  1. .gitignore +1 -0
  2. .vscode/settings.json +23 -0
  3. pyproject.toml +21 -0
  4. src/data.py +3 -1
  5. src/download.py +2 -2
  6. src/models.py +4 -3
.gitignore CHANGED
@@ -2,3 +2,4 @@
2
  .vscode/
3
  pyrightconfig.json
4
  *.jpg
 
 
2
  .vscode/
3
  pyrightconfig.json
4
  *.jpg
5
+ *.pyc
.vscode/settings.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "files.insertFinalNewline": true,
3
+ "jupyter.debugJustMyCode": false,
4
+ "editor.formatOnSave": true,
5
+ "editor.formatOnPaste": true,
6
+ "files.autoSave": "onFocusChange",
7
+ "editor.defaultFormatter": "ms-python.black-formatter",
8
+ "black-formatter.path": ["/opt/homebrew/bin/black"],
9
+ "black-formatter.args": ["--config", "./pyproject.toml"],
10
+ "black-formatter.cwd": "${workspaceFolder}",
11
+
12
+ "isort.check": true,
13
+ "python.analysis.typeCheckingMode": "basic",
14
+ "python.defaultInterpreterPath": "/opt/homebrew/bin/python3",
15
+ "[python]": {
16
+ "editor.defaultFormatter": "ms-python.black-formatter",
17
+ "editor.formatOnSave": true,
18
+ "editor.codeActionsOnSave": {
19
+ "source.organizeImports": "explicit"
20
+ },
21
+ },
22
+ "isort.args":["--profile", "black"],
23
+ }
pyproject.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line_length = 100
3
+
4
+ [tool.isort]
5
+ honor_noqa = true
6
+ line_length = 100
7
+ profile = "black"
8
+ verbose = false
9
+
10
+ known_first_party = [
11
+ # all folders you want to lump
12
+ "src",
13
+ ]
14
+
15
+ # Block below is google style import formatting https://pycqa.github.io/isort/docs/configuration/profiles.html
16
+ force_sort_within_sections = true
17
+ force_single_line = true
18
+ lexicographical = true
19
+ single_line_exclusions = ["typing"]
20
+ order_by_type = false
21
+ group_by_package = true
src/data.py CHANGED
@@ -5,7 +5,8 @@ from typing import Any
5
  import datasets
6
  from PIL import Image
7
  import torch
8
- from torch.utils.data import Dataset, DataLoader
 
9
  from torchvision import transforms
10
 
11
  from src import config
@@ -92,6 +93,7 @@ def get_dataset(
92
  if __name__ == "__main__":
93
  # do not want to do these imports in general
94
  import os
 
95
  from tqdm.auto import tqdm
96
 
97
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
5
  import datasets
6
  from PIL import Image
7
  import torch
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.data import Dataset
10
  from torchvision import transforms
11
 
12
  from src import config
 
93
  if __name__ == "__main__":
94
  # do not want to do these imports in general
95
  import os
96
+
97
  from tqdm.auto import tqdm
98
 
99
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
src/download.py CHANGED
@@ -1,11 +1,11 @@
 
1
  from io import BytesIO
2
  import pathlib
3
- from functools import partial
4
  from typing import Any
5
 
6
  import datasets
7
- from PIL import Image
8
  from loguru import logger
 
9
  import requests
10
  from tqdm.auto import tqdm
11
 
 
1
+ from functools import partial
2
  from io import BytesIO
3
  import pathlib
 
4
  from typing import Any
5
 
6
  import datasets
 
7
  from loguru import logger
8
+ from PIL import Image
9
  import requests
10
  from tqdm.auto import tqdm
11
 
src/models.py CHANGED
@@ -1,14 +1,15 @@
1
  from PIL import Image
2
- import transformers
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
-
7
  from transformers import PreTrainedModel
8
 
9
- from src.config import TinyCLIPConfig, TinyCLIPTextConfig, TinyCLIPVisionConfig
10
  from src import loss
11
  from src import vision_model
 
 
 
12
 
13
 
14
  class Projection(nn.Module):
 
1
  from PIL import Image
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ import transformers
6
  from transformers import PreTrainedModel
7
 
 
8
  from src import loss
9
  from src import vision_model
10
+ from src.config import TinyCLIPConfig
11
+ from src.config import TinyCLIPTextConfig
12
+ from src.config import TinyCLIPVisionConfig
13
 
14
 
15
  class Projection(nn.Module):