Tonic committed on
Commit 4299336 • 2 Parent(s): 1fcc6cb 13aacf1

add Nemo-Mistral-Minitron / Gradio 5

Files changed (4)
  1. README.md +5 -2
  2. app.py +32 -29
  3. globe.py +3 -5
  4. test.py +1662 -0
README.md CHANGED
@@ -4,11 +4,14 @@ emoji: 🏠🤖👌🏻
 colorFrom: blue
 colorTo: red
 sdk: gradio
-sdk_version: 4.39.0
+sdk_version: 5.0.0b5
 app_file: app.py
 pinned: true
 license: mit
 short_description: 'MiniNemo : High Performance With a SOTA Compression by Nvidia'
+short_description: State-of-the-Art Performance With a SOTA Compression
+thumbnail: >-
+  https://cdn-uploads.huggingface.co/production/uploads/62a3bb1cd0d8c2c2169f0b88/tJn4I1ea2HlGIbiNqM-xw.png
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,17 +1,23 @@
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-import json
 from globe import title, description, customtool , presentation1, presentation2, joinus
 import spaces
 
-model_path = "nvidia/Nemotron-Mini-4B-Instruct"
+model_path = "nvidia/Mistral-NeMo-Minitron-8B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 model = AutoModelForCausalLM.from_pretrained(model_path)
 
-# Create a pipeline
-pipe = pipeline("text-generation", model=model_path)
-pipe.tokenizer = tokenizer  # Assign tokenizer manually
+# Extract config info from model's configuration
+config_info = model.config
+
+# Create a Markdown string to display the complete model configuration information
+model_info_md = "### Model Configuration: Mistral-NeMo-Minitron-8B-Instruct\n\n"
+for key, value in config_info.to_dict().items():
+    model_info_md += f"- **{key.replace('_', ' ').capitalize()}**: {value}\n"
+
+pipe = pipeline("text-generation", model=model)
+pipe.tokenizer = tokenizer
 
 def create_prompt(system_message, user_message, tool_definition="", context=""):
     if tool_definition:
@@ -35,13 +41,10 @@ def create_prompt(system_message, user_message, tool_definition="", context=""):
 @spaces.GPU
 def generate_response(message, history, system_message, max_tokens, temperature, top_p, use_pipeline=False, tool_definition="", context=""):
     full_prompt = create_prompt(system_message, message, tool_definition, context)
-
+
     if use_pipeline:
-        messages = [
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": message},
-        ]
-        response = pipe(messages, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p)[0]['generated_text']
+        prompt = [{"role": "system", "content": system_message}, {"role": "user", "content": message}]
+        response = pipe(prompt, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, stop_strings=["<extra_id_1>"])[0]['generated_text']
     else:
         tokenized_chat = tokenizer.apply_chat_template(
             [
@@ -55,7 +58,7 @@ def generate_response(message, history, system_message, max_tokens, temperature,
 
     with torch.no_grad():
         output_ids = model.generate(
-            tokenized_chat,
+            tokenized_chat['input_ids'],
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
@@ -63,30 +66,30 @@ def generate_response(message, history, system_message, max_tokens, temperature,
         )
 
     response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
+
     assistant_response = response.split("<extra_id_1>Assistant\n")[-1].strip()
-
+
     if tool_definition and "<toolcall>" in assistant_response:
         tool_call = assistant_response.split("<toolcall>")[1].split("</toolcall>")[0]
         assistant_response += f"\n\nTool Call: {tool_call}\n\nNote: This is a simulated tool call. In a real scenario, the tool would be executed and its output would be used to generate a final response."
-
+
     return assistant_response
 
 with gr.Blocks() as demo:
     with gr.Row():
         gr.Markdown(title)
     with gr.Row():
-        gr.Markdown(description)
+        gr.Markdown(description)
     with gr.Row():
-        with gr.Group():
-            gr.Markdown(presentation1)
-        with gr.Group():
-            gr.Markdown(presentation2)
-    with gr.Row():
-        gr.Markdown(joinus)
+        with gr.Column(scale=1):
+            with gr.Group():
+                gr.Markdown(presentation1)
+        with gr.Column(scale=1):
+            with gr.Group():
+                gr.Markdown(model_info_md)
    with gr.Row():
        with gr.Column(scale=3):
-            chatbot = gr.Chatbot(label="🤖Nemotron-Mini", height=400)
+            chatbot = gr.Chatbot(label="🤖 Mistral-NeMo", height=400)
             msg = gr.Textbox(label="User Input", placeholder="Ask a question or request a task...")
             with gr.Accordion(label="🧪Advanced Settings", open=False):
                 system_message = gr.Textbox(
@@ -103,12 +106,12 @@ with gr.Blocks() as demo:
                 max_tokens = gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max Tokens")
                 temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
                 top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
-                use_pipeline = gr.Checkbox(label="Use🤷🏻‍♂️Pipeline", value=False)
-                use_tool = gr.Checkbox(label="Use Function🤖Calling", value=False)
+                use_pipeline = gr.Checkbox(label="Use Pipeline", value=False)
+                use_tool = gr.Checkbox(label="Use Function Calling", value=False)
             with gr.Column(visible=False) as tool_options:
                 tool_definition = gr.Code(
-                    label="🤖Tool Definition (JSON)",
-                    value=customtool,
+                    label="Tool Definition (JSON)",
+                    value="{}",
                     lines=15,
                     language="json"
                 )
@@ -116,7 +119,6 @@ with gr.Blocks() as demo:
             clear = gr.Button("Clear")
             send = gr.Button("Send")
 
-
     def user(user_message, history):
         return "", history + [[user_message, None]]
 
@@ -141,4 +143,5 @@ with gr.Blocks() as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.queue()
+    demo.launch()
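
For reference, a minimal self-contained sketch of the manual generation path this commit switches to, run outside the Gradio app. It assumes `apply_chat_template` is called with `return_dict=True`, which is what makes the `tokenized_chat['input_ids']` indexing above valid; the model id and the `<extra_id_1>Assistant` split come from the diff itself, while the dtype/device settings are illustrative assumptions:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "nvidia/Mistral-NeMo-Minitron-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# torch_dtype/device_map are illustrative assumptions, not taken from the diff
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What does pruning do to a language model?"},
]
tokenized_chat = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,  # makes tokenized_chat['input_ids'] valid, as used above
).to(model.device)

with torch.no_grad():
    output_ids = model.generate(
        tokenized_chat["input_ids"],
        attention_mask=tokenized_chat["attention_mask"],
        max_new_tokens=256,
        do_sample=True,  # temperature/top_p only take effect when sampling
        temperature=0.7,
        top_p=0.95,
    )

response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(response.split("<extra_id_1>Assistant\n")[-1].strip())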
globe.py CHANGED
@@ -3,16 +3,14 @@ joinus = """
 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
 """
 
-title = """# 🙋🏻‍♂️Welcome to Tonic's 🤖 Nemotron-Mini-4B Demo 🚀"""
+title = """# 🙋🏻‍♂️Welcome to Tonic's 🤖 Mistral-NeMo-Minitron Demo 🚀"""
 
-description = """🤖Nemotron-Mini-4B-Instruct is a model for generating responses for roleplaying, retrieval augmented generation, and function calling. It is a small language model (SLM) optimized through distillation, pruning and quantization for speed and on-device deployment. It is a fine-tuned version of [nvidia/Minitron-4B-Base](https://huggingface.co/nvidia/Minitron-4B-Base), which was pruned and distilled from [Nemotron-4 15B](https://arxiv.org/abs/2402.16819) using [our LLM compression technique](https://arxiv.org/abs/2407.14679). This instruct model is optimized for roleplay, RAG QA, and function calling in English. It supports a context length of 4,096 tokens. This model is ready for commercial use.
+description = """nvidia/🤖Mistral-NeMo-Minitron-8B-Instruct is a model for generating responses for various text-generation tasks including roleplaying, retrieval augmented generation, and function calling.
 """
 
 presentation1 = """Try this model on [build.nvidia.com](https://build.nvidia.com/nvidia/nemotron-mini-4b-instruct).
 
-**Model Developer:** NVIDIA
-
-**Model Dates:** 🤖Nemotron-Mini-4B-Instruct was trained between February 2024 and Aug 2024.
+Mistral-NeMo-Minitron-8B-Instruct is a model for generating responses for various text-generation tasks including roleplaying, retrieval augmented generation, and function calling. It is a fine-tuned version of [nvidia/Mistral-NeMo-Minitron-8B-Base](https://huggingface.co/nvidia/Mistral-NeMo-Minitron-8B-Base), which was pruned and distilled from [Mistral-NeMo 12B](https://huggingface.co/nvidia/Mistral-NeMo-12B-Base) using [our LLM compression technique](https://arxiv.org/abs/2407.14679). The model was trained using a multi-stage SFT and preference-based alignment technique with [NeMo Aligner](https://github.com/NVIDIA/NeMo-Aligner). For details on the alignment technique, please refer to the [Nemotron-4 340B Technical Report](https://arxiv.org/abs/2406.11704).
 
 ### License
 
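
The function-calling path kept in app.py above embeds a JSON tool definition in the prompt and then scans the completion for <toolcall>...</toolcall> tags. Below is a small sketch of that parsing step, using the same split logic as generate_response(); the get_weather tool name and schema here are hypothetical, not from the repo:

import json

# Hypothetical tool definition; in the app this comes from the "Tool Definition (JSON)" box.
tool_definition = json.dumps({
    "name": "get_weather",
    "description": "Look up the current weather for a city",
    "parameters": {"city": {"type": "string"}},
})

# A completion shaped the way the app's parser expects.
assistant_response = (
    'Let me check that for you. '
    '<toolcall>{"name": "get_weather", "arguments": {"city": "Paris"}}</toolcall>'
)

# Same extraction logic as generate_response() in app.py.
if tool_definition and "<toolcall>" in assistant_response:
    tool_call = assistant_response.split("<toolcall>")[1].split("</toolcall>")[0]
    call = json.loads(tool_call)
    print(call["name"], call["arguments"])  # get_weather {'city': 'Paris'}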
test.py ADDED
@@ -0,0 +1,1662 @@
+# mypy: allow-untyped-defs
+import warnings
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+from .linear import NonDynamicallyQuantizableLinear
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+from torch.nn.parameter import Parameter
+from .module import Module
+from .. import functional as F
+
+__all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
+           'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
+           'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink',
+           'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']
+
+
+class Threshold(Module):
+    r"""Thresholds each element of the input Tensor.
+
+    Threshold is defined as:
+
+    .. math::
+        y =
+        \begin{cases}
+        x, &\text{ if } x > \text{threshold} \\
+        \text{value}, &\text{ otherwise }
+        \end{cases}
+
+    Args:
+        threshold: The value to threshold at
+        value: The value to replace with
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    Examples::
+
+        >>> m = nn.Threshold(0.1, 20)
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['threshold', 'value', 'inplace']
+
+    threshold: float
+    value: float
+    inplace: bool
+
+    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
+        super().__init__()
+        self.threshold = threshold
+        self.value = value
+        self.inplace = inplace
+        # TODO: check in THNN (if inplace == True, then assert value <= threshold)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.threshold(input, self.threshold, self.value, self.inplace)
+
+    def extra_repr(self):
+        inplace_str = ', inplace=True' if self.inplace else ''
+        return f'threshold={self.threshold}, value={self.value}{inplace_str}'
+
+
+class ReLU(Module):
+    r"""Applies the rectified linear unit function element-wise.
+
+    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
+
+    Args:
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/ReLU.png
+
+    Examples::
+
+        >>> m = nn.ReLU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+
+
+      An implementation of CReLU - https://arxiv.org/abs/1603.05201
+
+        >>> m = nn.ReLU()
+        >>> input = torch.randn(2).unsqueeze(0)
+        >>> output = torch.cat((m(input), m(-input)))
+    """
+
+    __constants__ = ['inplace']
+    inplace: bool
+
+    def __init__(self, inplace: bool = False):
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.relu(input, inplace=self.inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = 'inplace=True' if self.inplace else ''
+        return inplace_str
+
+
+class RReLU(Module):
+    r"""Applies the randomized leaky rectified linear unit function, element-wise.
+
+    Method described in the paper:
+    `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_.
+
+    The function is defined as:
+
+    .. math::
+        \text{RReLU}(x) =
+        \begin{cases}
+            x & \text{if } x \geq 0 \\
+            ax & \text{ otherwise }
+        \end{cases}
+
+    where :math:`a` is randomly sampled from uniform distribution
+    :math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during
+    evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`.
+
+    Args:
+        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
+        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/RReLU.png
+
+    Examples::
+
+        >>> m = nn.RReLU(0.1, 0.3)
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+
+    """
+
+    __constants__ = ['lower', 'upper', 'inplace']
+
+    lower: float
+    upper: float
+    inplace: bool
+
+    def __init__(
+        self,
+        lower: float = 1. / 8,
+        upper: float = 1. / 3,
+        inplace: bool = False
+    ):
+        super().__init__()
+        self.lower = lower
+        self.upper = upper
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
+
+    def extra_repr(self):
+        inplace_str = ', inplace=True' if self.inplace else ''
+        return f'lower={self.lower}, upper={self.upper}{inplace_str}'
+
+
+class Hardtanh(Module):
+    r"""Applies the HardTanh function element-wise.
+
+    HardTanh is defined as:
+
+    .. math::
+        \text{HardTanh}(x) = \begin{cases}
+            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
+            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
+            x & \text{ otherwise } \\
+        \end{cases}
+
+    Args:
+        min_val: minimum value of the linear region range. Default: -1
+        max_val: maximum value of the linear region range. Default: 1
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Keyword arguments :attr:`min_value` and :attr:`max_value`
+    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Hardtanh.png
+
+    Examples::
+
+        >>> m = nn.Hardtanh(-2, 2)
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['min_val', 'max_val', 'inplace']
+
+    min_val: float
+    max_val: float
+    inplace: bool
+
+    def __init__(
+        self,
+        min_val: float = -1.,
+        max_val: float = 1.,
+        inplace: bool = False,
+        min_value: Optional[float] = None,
+        max_value: Optional[float] = None
+    ) -> None:
+        super().__init__()
+        if min_value is not None:
+            warnings.warn(
+                "keyword argument `min_value` is deprecated and rename to `min_val`",
+                FutureWarning,
+                stacklevel=2,
+            )
+            min_val = min_value
+        if max_value is not None:
+            warnings.warn(
+                "keyword argument `max_value` is deprecated and rename to `max_val`",
+                FutureWarning,
+                stacklevel=2,
+            )
+            max_val = max_value
+
+        self.min_val = min_val
+        self.max_val = max_val
+        self.inplace = inplace
+        assert self.max_val > self.min_val
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = ', inplace=True' if self.inplace else ''
+        return f'min_val={self.min_val}, max_val={self.max_val}{inplace_str}'
+
+
+class ReLU6(Hardtanh):
+    r"""Applies the ReLU6 function element-wise.
+
+    .. math::
+        \text{ReLU6}(x) = \min(\max(0,x), 6)
+
+    Args:
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/ReLU6.png
+
+    Examples::
+
+        >>> m = nn.ReLU6()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def __init__(self, inplace: bool = False):
+        super().__init__(0., 6., inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = 'inplace=True' if self.inplace else ''
+        return inplace_str
+
+
+class Sigmoid(Module):
+    r"""Applies the Sigmoid function element-wise.
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Sigmoid.png
+
+    Examples::
+
+        >>> m = nn.Sigmoid()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def forward(self, input: Tensor) -> Tensor:
+        return torch.sigmoid(input)
+
+
+class Hardsigmoid(Module):
+    r"""Applies the Hardsigmoid function element-wise.
+
+    Hardsigmoid is defined as:
+
+    .. math::
+        \text{Hardsigmoid}(x) = \begin{cases}
+            0 & \text{if~} x \le -3, \\
+            1 & \text{if~} x \ge +3, \\
+            x / 6 + 1 / 2 & \text{otherwise}
+        \end{cases}
+
+    Args:
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Hardsigmoid.png
+
+    Examples::
+
+        >>> m = nn.Hardsigmoid()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['inplace']
+
+    inplace: bool
+
+    def __init__(self, inplace : bool = False) -> None:
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.hardsigmoid(input, self.inplace)
+
+
+class Tanh(Module):
+    r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
+
+    Tanh is defined as:
+
+    .. math::
+        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Tanh.png
+
+    Examples::
+
+        >>> m = nn.Tanh()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def forward(self, input: Tensor) -> Tensor:
+        return torch.tanh(input)
+
+
+class SiLU(Module):
+    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
+
+    The SiLU function is also known as the swish function.
+
+    .. math::
+        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
+
+    .. note::
+        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
+        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
+        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
+        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
+        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
+        where the SiLU was experimented with later.
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/SiLU.png
+
+    Examples::
+
+        >>> m = nn.SiLU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['inplace']
+    inplace: bool
+
+    def __init__(self, inplace: bool = False):
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.silu(input, inplace=self.inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = 'inplace=True' if self.inplace else ''
+        return inplace_str
+
+
+class Mish(Module):
+    r"""Applies the Mish function, element-wise.
+
+    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+    .. note::
+        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Mish.png
+
+    Examples::
+
+        >>> m = nn.Mish()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['inplace']
+    inplace: bool
+
+    def __init__(self, inplace: bool = False):
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.mish(input, inplace=self.inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = 'inplace=True' if self.inplace else ''
+        return inplace_str
+
+
+class Hardswish(Module):
+    r"""Applies the Hardswish function, element-wise.
+
+    Method described in the paper: `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
+
+    Hardswish is defined as:
+
+    .. math::
+        \text{Hardswish}(x) = \begin{cases}
+            0 & \text{if~} x \le -3, \\
+            x & \text{if~} x \ge +3, \\
+            x \cdot (x + 3) /6 & \text{otherwise}
+        \end{cases}
+
+    Args:
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Hardswish.png
+
+    Examples::
+
+        >>> m = nn.Hardswish()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['inplace']
+
+    inplace: bool
+
+    def __init__(self, inplace : bool = False) -> None:
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.hardswish(input, self.inplace)
+
+
+class ELU(Module):
+    r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
+
+    Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
+    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
+
+    ELU is defined as:
+
+    .. math::
+        \text{ELU}(x) = \begin{cases}
+        x, & \text{ if } x > 0\\
+        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
+        \end{cases}
+
+    Args:
+        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/ELU.png
+
+    Examples::
+
+        >>> m = nn.ELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['alpha', 'inplace']
+    alpha: float
+    inplace: bool
+
+    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
+        super().__init__()
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.elu(input, self.alpha, self.inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = ', inplace=True' if self.inplace else ''
+        return f'alpha={self.alpha}{inplace_str}'
+
+
+class CELU(Module):
+    r"""Applies the CELU function element-wise.
+
+    .. math::
+        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
+
+    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
+
+    Args:
+        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/CELU.png
+
+    Examples::
+
+        >>> m = nn.CELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+
+    .. _`Continuously Differentiable Exponential Linear Units`:
+        https://arxiv.org/abs/1704.07483
+    """
+
+    __constants__ = ['alpha', 'inplace']
+    alpha: float
+    inplace: bool
+
+    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
+        super().__init__()
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.celu(input, self.alpha, self.inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = ', inplace=True' if self.inplace else ''
+        return f'alpha={self.alpha}{inplace_str}'
+
+
+class SELU(Module):
+    r"""Applies the SELU function element-wise.
+
+    .. math::
+        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
+
+    with :math:`\alpha = 1.6732632423543772848170429916717` and
+    :math:`\text{scale} = 1.0507009873554804934193349852946`.
+
+    .. warning::
+        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
+        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
+        in order to get `Self-Normalizing Neural Networks`_.
+        See :func:`torch.nn.init.calculate_gain` for more information.
+
+    More details can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+    Args:
+        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/SELU.png
+
+    Examples::
+
+        >>> m = nn.SELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+
+    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+    """
+
+    __constants__ = ['inplace']
+    inplace: bool
+
+    def __init__(self, inplace: bool = False) -> None:
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.selu(input, self.inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = 'inplace=True' if self.inplace else ''
+        return inplace_str
+
+
+class GLU(Module):
+    r"""Applies the gated linear unit function.
+
+    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
+    of the input matrices and :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+
+    Shape:
+        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+    Examples::
+
+        >>> m = nn.GLU()
+        >>> input = torch.randn(4, 2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['dim']
+    dim: int
+
+    def __init__(self, dim: int = -1) -> None:
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.glu(input, self.dim)
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}'
+
+
+class GELU(Module):
+    r"""Applies the Gaussian Error Linear Units function.
+
+    .. math:: \text{GELU}(x) = x * \Phi(x)
+
+    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
+
+    When the approximate argument is 'tanh', Gelu is estimated with:
+
+    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
+
+    Args:
+        approximate (str, optional): the gelu approximation algorithm to use:
+            ``'none'`` | ``'tanh'``. Default: ``'none'``
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/GELU.png
+
+    Examples::
+
+        >>> m = nn.GELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['approximate']
+    approximate: str
+
+    def __init__(self, approximate: str = 'none') -> None:
+        super().__init__()
+        self.approximate = approximate
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.gelu(input, approximate=self.approximate)
+
+    def extra_repr(self) -> str:
+        return f'approximate={repr(self.approximate)}'
+
+
+class Hardshrink(Module):
+    r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
+
+    Hardshrink is defined as:
+
+    .. math::
+        \text{HardShrink}(x) =
+        \begin{cases}
+        x, & \text{ if } x > \lambda \\
+        x, & \text{ if } x < -\lambda \\
+        0, & \text{ otherwise }
+        \end{cases}
+
+    Args:
+        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Hardshrink.png
+
+    Examples::
+
+        >>> m = nn.Hardshrink()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['lambd']
+    lambd: float
+
+    def __init__(self, lambd: float = 0.5) -> None:
+        super().__init__()
+        self.lambd = lambd
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.hardshrink(input, self.lambd)
+
+    def extra_repr(self) -> str:
+        return f'{self.lambd}'
+
+
+class LeakyReLU(Module):
+    r"""Applies the LeakyReLU function element-wise.
+
+    .. math::
+        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
+
+
+    or
+
+    .. math::
+        \text{LeakyReLU}(x) =
+        \begin{cases}
+        x, & \text{ if } x \geq 0 \\
+        \text{negative\_slope} \times x, & \text{ otherwise }
+        \end{cases}
+
+    Args:
+        negative_slope: Controls the angle of the negative slope (which is used for
+          negative input values). Default: 1e-2
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(*)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(*)`, same shape as the input
+
+    .. image:: ../scripts/activation_images/LeakyReLU.png
+
+    Examples::
+
+        >>> m = nn.LeakyReLU(0.1)
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['inplace', 'negative_slope']
+    inplace: bool
+    negative_slope: float
+
+    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
+        super().__init__()
+        self.negative_slope = negative_slope
+        self.inplace = inplace
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.leaky_relu(input, self.negative_slope, self.inplace)
+
+    def extra_repr(self) -> str:
+        inplace_str = ', inplace=True' if self.inplace else ''
+        return f'negative_slope={self.negative_slope}{inplace_str}'
+
+
+class LogSigmoid(Module):
+    r"""Applies the Logsigmoid function element-wise.
+
+    .. math::
+        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/LogSigmoid.png
+
+    Examples::
+
+        >>> m = nn.LogSigmoid()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.logsigmoid(input)
+
+
+class Softplus(Module):
+    r"""Applies the Softplus function element-wise.
+
+    .. math::
+        \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
+
+    SoftPlus is a smooth approximation to the ReLU function and can be used
+    to constrain the output of a machine to always be positive.
+
+    For numerical stability the implementation reverts to the linear function
+    when :math:`input \times \beta > threshold`.
+
+    Args:
+        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
+        threshold: values above this revert to a linear function. Default: 20
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Softplus.png
+
+    Examples::
+
+        >>> m = nn.Softplus()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['beta', 'threshold']
+    beta: float
+    threshold: float
+
+    def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None:
+        super().__init__()
+        self.beta = beta
+        self.threshold = threshold
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.softplus(input, self.beta, self.threshold)
+
+    def extra_repr(self) -> str:
+        return f'beta={self.beta}, threshold={self.threshold}'
+
+
+class Softshrink(Module):
+    r"""Applies the soft shrinkage function element-wise.
+
+    .. math::
+        \text{SoftShrinkage}(x) =
+        \begin{cases}
+        x - \lambda, & \text{ if } x > \lambda \\
+        x + \lambda, & \text{ if } x < -\lambda \\
+        0, & \text{ otherwise }
+        \end{cases}
+
+    Args:
+        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    .. image:: ../scripts/activation_images/Softshrink.png
+
+    Examples::
+
+        >>> m = nn.Softshrink()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    __constants__ = ['lambd']
+    lambd: float
+
+    def __init__(self, lambd: float = 0.5) -> None:
+        super().__init__()
+        self.lambd = lambd
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.softshrink(input, self.lambd)
+
+    def extra_repr(self) -> str:
+        return str(self.lambd)
+
+
+def _check_arg_device(x: Optional[torch.Tensor]) -> bool:
+    if x is not None:
+        return x.device.type in ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
+    return True
+
+
+def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool:
+    if x is not None:
+        return x.requires_grad
+    return False
+
+
+def _is_make_fx_tracing():
+    if not torch.jit.is_scripting():
+        torch_dispatch_mode_stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack()
+        return any(type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode for x in torch_dispatch_mode_stack)
+    else:
+        return False
+
+
956
+ [docs]class MultiheadAttention(Module):
957
+ r"""Allows the model to jointly attend to information from different representation subspaces.
958
+
959
+ Method described in the paper:
960
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
961
+
962
+ Multi-Head Attention is defined as:
963
+
964
+ .. math::
965
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
966
+
967
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
968
+
969
+ ``nn.MultiHeadAttention`` will use the optimized implementations of
970
+ ``scaled_dot_product_attention()`` when possible.
971
+
972
+ In addition to support for the new ``scaled_dot_product_attention()``
973
+ function, for speeding up Inference, MHA will use
974
+ fastpath inference with support for Nested Tensors, iff:
975
+
976
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
977
+ - inputs are batched (3D) with ``batch_first==True``
978
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
979
+ - training is disabled (using ``.eval()``)
980
+ - ``add_bias_kv`` is ``False``
981
+ - ``add_zero_attn`` is ``False``
982
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
983
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
984
+ nor ``attn_mask`` is passed
985
+ - autocast is disabled
986
+
987
+ If the optimized inference fastpath implementation is in use, a
988
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
989
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
990
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
991
+ will be returned, and an additional speedup proportional to the fraction of the input
992
+ that is padding can be expected.
993
+
994
+ Args:
995
+ embed_dim: Total dimension of the model.
996
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
997
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
998
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
999
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
1000
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
1001
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
1002
+ Default: ``False``.
1003
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
1004
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
1005
+ batch_first: If ``True``, then the input and output tensors are provided
1006
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
1007
+
1008
+ Examples::
1009
+
1010
+ >>> # xdoctest: +SKIP
1011
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
1012
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
1013
+
1014
+ .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
1015
+ https://arxiv.org/abs/2205.14135
1016
+
1017
+ """
1018
+
1019
+ __constants__ = ['batch_first']
1020
+ bias_k: Optional[torch.Tensor]
1021
+ bias_v: Optional[torch.Tensor]
1022
+
1023
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
1024
+ kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
1025
+ if embed_dim <= 0 or num_heads <= 0:
1026
+ raise ValueError(
1027
+ f"embed_dim and num_heads must be greater than 0,"
1028
+ f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
1029
+ )
1030
+ factory_kwargs = {'device': device, 'dtype': dtype}
1031
+ super().__init__()
1032
+ self.embed_dim = embed_dim
1033
+ self.kdim = kdim if kdim is not None else embed_dim
1034
+ self.vdim = vdim if vdim is not None else embed_dim
1035
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
1036
+
1037
+ self.num_heads = num_heads
1038
+ self.dropout = dropout
1039
+ self.batch_first = batch_first
1040
+ self.head_dim = embed_dim // num_heads
1041
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
1042
+
1043
+ if not self._qkv_same_embed_dim:
1044
+ self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
1045
+ self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
1046
+ self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
1047
+ self.register_parameter('in_proj_weight', None)
1048
+ else:
1049
+ self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
1050
+ self.register_parameter('q_proj_weight', None)
1051
+ self.register_parameter('k_proj_weight', None)
1052
+ self.register_parameter('v_proj_weight', None)
1053
+
1054
+ if bias:
1055
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
1056
+ else:
1057
+ self.register_parameter('in_proj_bias', None)
1058
+ self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
1059
+
1060
+ if add_bias_kv:
1061
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
1062
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
1063
+ else:
1064
+ self.bias_k = self.bias_v = None
1065
+
1066
+ self.add_zero_attn = add_zero_attn
1067
+
1068
+ self._reset_parameters()
1069
+
1070
+ def _reset_parameters(self):
1071
+ if self._qkv_same_embed_dim:
1072
+ xavier_uniform_(self.in_proj_weight)
1073
+ else:
1074
+ xavier_uniform_(self.q_proj_weight)
1075
+ xavier_uniform_(self.k_proj_weight)
1076
+ xavier_uniform_(self.v_proj_weight)
1077
+
1078
+ if self.in_proj_bias is not None:
1079
+ constant_(self.in_proj_bias, 0.)
1080
+ constant_(self.out_proj.bias, 0.)
1081
+ if self.bias_k is not None:
1082
+ xavier_normal_(self.bias_k)
1083
+ if self.bias_v is not None:
1084
+ xavier_normal_(self.bias_v)
1085
+
1086
+ def __setstate__(self, state):
1087
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
1088
+ if '_qkv_same_embed_dim' not in state:
1089
+ state['_qkv_same_embed_dim'] = True
1090
+
1091
+ super().__setstate__(state)
1092
+
1093
+ [docs] def forward(
1094
+ self,
1095
+ query: Tensor,
1096
+ key: Tensor,
1097
+ value: Tensor,
1098
+ key_padding_mask: Optional[Tensor] = None,
1099
+ need_weights: bool = True,
1100
+ attn_mask: Optional[Tensor] = None,
1101
+ average_attn_weights: bool = True,
1102
+ is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
1103
+ r"""Compute attention outputs using query, key, and value embeddings.
1104
+
1105
+ Supports optional parameters for padding, masks and attention weights.
1106
+
1107
+ Args:
1108
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
1109
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
1110
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
1111
+ Queries are compared against key-value pairs to produce the output.
1112
+ See "Attention Is All You Need" for more details.
1113
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
1114
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
1115
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
1116
+ See "Attention Is All You Need" for more details.
1117
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
1118
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
1119
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
1120
+ See "Attention Is All You Need" for more details.
1121
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
1122
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
1123
+ Binary and float masks are supported.
1124
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
1125
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
1126
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
1127
+ Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
1128
+ and achieve the best performance for MHA.
1129
+ Default: ``True``.
1130
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
1131
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
1132
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
1133
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
1134
+ Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
1135
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
1136
+ the attention weight.
1137
+ If both attn_mask and key_padding_mask are supplied, their types should match.
1138
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
1139
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
1140
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
1141
+ is_causal: If specified, applies a causal mask as attention mask.
1142
+ Default: ``False``.
1143
+ Warning:
1144
+ ``is_causal`` provides a hint that ``attn_mask`` is the
1145
+ causal mask. Providing incorrect hints can result in
1146
+ incorrect execution, including forward and backward
1147
+ compatibility.
1148
+
1149
+ Outputs:
1150
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
1151
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
1152
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
1153
+ embedding dimension ``embed_dim``.
1154
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
1155
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
1156
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
1157
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
1158
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
1159
+
1160
+ .. note::
1161
+ `batch_first` argument is ignored for unbatched inputs.
1162
+ """
+         why_not_fast_path = ''
+         if ((attn_mask is not None and torch.is_floating_point(attn_mask))
+                 or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
+             why_not_fast_path = "floating-point masks are not supported for fast path."
+ 
+         is_batched = query.dim() == 3
+ 
+         key_padding_mask = F._canonical_mask(
+             mask=key_padding_mask,
+             mask_name="key_padding_mask",
+             other_type=F._none_or_dtype(attn_mask),
+             other_name="attn_mask",
+             target_type=query.dtype
+         )
+ 
+         attn_mask = F._canonical_mask(
+             mask=attn_mask,
+             mask_name="attn_mask",
+             other_type=None,
+             other_name="",
+             target_type=query.dtype,
+             check_other=False,
+         )
+ 
+         is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
+ 
+         if not is_fastpath_enabled:
+             why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
+         elif not is_batched:
+             why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
+         elif query is not key or key is not value:
+             # When lifting this restriction, don't forget to either
+             # enforce that the dtypes all match or test cases where
+             # they don't!
+             why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
+         elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
+             why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+         elif self.in_proj_weight is None:
+             why_not_fast_path = "in_proj_weight was None"
+         elif query.dtype != self.in_proj_weight.dtype:
+             # this case will fail anyway, but at least they'll get a useful error message.
+             why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+         elif self.training:
+             why_not_fast_path = "training is enabled"
+         elif (self.num_heads % 2) != 0:
+             why_not_fast_path = "self.num_heads is not even"
+         elif not self.batch_first:
+             why_not_fast_path = "batch_first was not True"
+         elif self.bias_k is not None:
+             why_not_fast_path = "self.bias_k was not None"
+         elif self.bias_v is not None:
+             why_not_fast_path = "self.bias_v was not None"
+         elif self.add_zero_attn:
+             why_not_fast_path = "add_zero_attn was enabled"
+         elif not self._qkv_same_embed_dim:
+             why_not_fast_path = "_qkv_same_embed_dim was not True"
+         elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
+             why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
+                                  is not supported with NestedTensor input"
+         elif torch.is_autocast_enabled():
+             why_not_fast_path = "autocast is enabled"
+ 
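+         # The checks above mean the fast path is only reachable for batched,
+         # non-nested self-attention in eval mode with ``batch_first=True``,
+         # matching dtypes, and autocast disabled. A minimal sketch of inputs
+         # that satisfy these conditions (the module construction below is an
+         # illustrative assumption):
+         #
+         #     mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True).eval()
+         #     x = torch.randn(4, 16, 8)      # (N, L, E), batched
+         #     with torch.no_grad():
+         #         out, _ = mha(x, x, x)      # query is key is value -> self-attention
+ 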
+         if not why_not_fast_path:
+             tensor_args = (
+                 query,
+                 key,
+                 value,
+                 self.in_proj_weight,
+                 self.in_proj_bias,
+                 self.out_proj.weight,
+                 self.out_proj.bias,
+             )
+             # Additional runtime checks (torch_function overrides, tracing,
+             # device placement, autograd) that gate the native fast path.
+             if torch.overrides.has_torch_function(tensor_args):
+                 why_not_fast_path = "some Tensor argument has_torch_function"
+             elif _is_make_fx_tracing():
+                 why_not_fast_path = "we are running make_fx tracing"
+             elif not all(_check_arg_device(x) for x in tensor_args):
+                 why_not_fast_path = ("some Tensor argument's device is neither one of "
+                                      f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
+             elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
+                 why_not_fast_path = ("grad is enabled and at least one of query or the "
+                                      "input/output projection weights or biases requires_grad")
+             if not why_not_fast_path:
+                 merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
+ 
+                 if self.in_proj_bias is not None and self.in_proj_weight is not None:
+                     return torch._native_multi_head_attention(
+                         query,
+                         key,
+                         value,
+                         self.embed_dim,
+                         self.num_heads,
+                         self.in_proj_weight,
+                         self.in_proj_bias,
+                         self.out_proj.weight,
+                         self.out_proj.bias,
+                         merged_mask,
+                         need_weights,
+                         average_attn_weights,
+                         mask_type)
+ 
+         any_nested = query.is_nested or key.is_nested or value.is_nested
+         assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
+                                 f"The fast path was not hit because {why_not_fast_path}")
+ 
+         if self.batch_first and is_batched:
+             # make sure that the transpose op does not affect the "is" property
+             if key is value:
+                 if query is key:
+                     query = key = value = query.transpose(1, 0)
+                 else:
+                     query, key = (x.transpose(1, 0) for x in (query, key))
+                     value = key
+             else:
+                 query, key, value = (x.transpose(1, 0) for x in (query, key, value))
+ 
+         if not self._qkv_same_embed_dim:
+             attn_output, attn_output_weights = F.multi_head_attention_forward(
+                 query, key, value, self.embed_dim, self.num_heads,
+                 self.in_proj_weight, self.in_proj_bias,
+                 self.bias_k, self.bias_v, self.add_zero_attn,
+                 self.dropout, self.out_proj.weight, self.out_proj.bias,
+                 training=self.training,
+                 key_padding_mask=key_padding_mask, need_weights=need_weights,
+                 attn_mask=attn_mask,
+                 use_separate_proj_weight=True,
+                 q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+                 v_proj_weight=self.v_proj_weight,
+                 average_attn_weights=average_attn_weights,
+                 is_causal=is_causal)
+         else:
+             attn_output, attn_output_weights = F.multi_head_attention_forward(
+                 query, key, value, self.embed_dim, self.num_heads,
+                 self.in_proj_weight, self.in_proj_bias,
+                 self.bias_k, self.bias_v, self.add_zero_attn,
+                 self.dropout, self.out_proj.weight, self.out_proj.bias,
+                 training=self.training,
+                 key_padding_mask=key_padding_mask,
+                 need_weights=need_weights,
+                 attn_mask=attn_mask,
+                 average_attn_weights=average_attn_weights,
+                 is_causal=is_causal)
+         if self.batch_first and is_batched:
+             return attn_output.transpose(1, 0), attn_output_weights
+         else:
+             return attn_output, attn_output_weights
+ 
+ 
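+     # A minimal sketch of the mask arguments documented above (shapes are
+     # illustrative assumptions): a boolean ``key_padding_mask`` marks positions
+     # to ignore, and a boolean causal ``attn_mask`` of matching type blocks
+     # attention to future positions.
+     #
+     #     mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
+     #     x = torch.randn(2, 5, 8)                                   # (N, L, E)
+     #     pad = torch.tensor([[False, False, False, False, False],
+     #                         [False, False, False, True, True]])    # (N, S), True = ignore
+     #     causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
+     #     out, w = mha(x, x, x, key_padding_mask=pad, attn_mask=causal)
+ 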
+     def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
+                     query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
+         r"""Determine mask type and combine masks if necessary.
+ 
+         If only one mask is provided, that mask
+         and the corresponding mask type will be returned. If both masks are provided, they will both be
+         expanded to shape ``(batch_size, num_heads, seq_len, seq_len)`` and added together (for the
+         canonicalized float masks this is equivalent to a logical ``or`` of the masked positions),
+         and mask type 2 will be returned.
+ 
+         Args:
+             attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
+             key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
+             query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
+ 
+         Returns:
+             merged_mask: merged mask
+             mask_type: merged mask type (0, 1, or 2)
+         """
+         mask_type: Optional[int] = None
+         merged_mask: Optional[Tensor] = None
+ 
+         if key_padding_mask is not None:
+             mask_type = 1
+             merged_mask = key_padding_mask
+ 
+         if attn_mask is not None:
+             # In this branch query can't be a nested tensor, so it has a shape
+             batch_size, seq_len, _ = query.shape
+             mask_type = 2
+ 
+             # Always expands attn_mask to 4D
+             if attn_mask.dim() == 3:
+                 attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
+             else:  # attn_mask.dim() == 2:
+                 attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
+             merged_mask = attn_mask_expanded
+ 
+             if key_padding_mask is not None:
+                 key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
+                 merged_mask = attn_mask_expanded + key_padding_mask_expanded
+ 
+         # no attn_mask and no key_padding_mask, returns None, None
+         return merged_mask, mask_type
+ 
+ 
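+ # A small sketch of ``merge_masks`` (assuming float masks, as produced by mask
+ # canonicalization in ``forward``): with both masks present, the result is the
+ # 4D sum of the expanded masks and ``mask_type`` is 2.
+ #
+ #     mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
+ #     q = torch.randn(2, 5, 8)          # (batch_size, seq_len, embed_dim)
+ #     attn = torch.zeros(5, 5)          # mask type 0
+ #     kpm = torch.zeros(2, 5)           # mask type 1
+ #     merged, mask_type = mha.merge_masks(attn, kpm, q)
+ #     assert merged.shape == (2, 2, 5, 5) and mask_type == 2
+ 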
+ class PReLU(Module):
+     r"""Applies the element-wise PReLU function.
+ 
+     .. math::
+         \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+ 
+     or
+ 
+     .. math::
+         \text{PReLU}(x) =
+         \begin{cases}
+         x, & \text{ if } x \ge 0 \\
+         ax, & \text{ otherwise }
+         \end{cases}
+ 
+     Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
+     parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
+     a separate :math:`a` is used for each input channel.
+ 
+     .. note::
+         For good performance, weight decay should not be used when learning :math:`a`.
+ 
+     .. note::
+         Channel dim is the 2nd dim of input. When input has dims < 2, then there is
+         no channel dim and the number of channels = 1.
+ 
+     Args:
+         num_parameters (int): number of :math:`a` to learn.
+             Although it takes an int as input, only two values are legitimate:
+             1, or the number of channels at input. Default: 1
+         init (float): the initial value of :math:`a`. Default: 0.25
+ 
+     Shape:
+         - Input: :math:`(*)` where :math:`*` means any number of additional
+           dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+ 
+     Attributes:
+         weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
+ 
+     .. image:: ../scripts/activation_images/PReLU.png
+ 
+     Examples::
+ 
+         >>> m = nn.PReLU()
+         >>> input = torch.randn(2)
+         >>> output = m(input)
+     """
+ 
+     __constants__ = ['num_parameters']
+     num_parameters: int
+ 
+     def __init__(self, num_parameters: int = 1, init: float = 0.25,
+                  device=None, dtype=None) -> None:
+         factory_kwargs = {'device': device, 'dtype': dtype}
+         self.num_parameters = num_parameters
+         super().__init__()
+         self.init = init
+         self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
+         self.reset_parameters()
+ 
+     def reset_parameters(self):
+         # initialize every learnable slope ``a`` to its configured starting value
+         torch.nn.init.constant_(self.weight, self.init)
+ 
+     def forward(self, input: Tensor) -> Tensor:
+         return F.prelu(input, self.weight)
+ 
+     def extra_repr(self) -> str:
+         return f'num_parameters={self.num_parameters}'
+ 
+ 
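+ # Following the weight-decay note above, PReLU slopes are typically placed in a
+ # parameter group without decay. A minimal sketch (``model`` and the optimizer
+ # settings are assumptions, not part of this module):
+ #
+ #     prelu_params = [p for m in model.modules()
+ #                     if isinstance(m, nn.PReLU) for p in m.parameters()]
+ #     ids = {id(p) for p in prelu_params}
+ #     others = [p for p in model.parameters() if id(p) not in ids]
+ #     optim = torch.optim.SGD([{'params': others, 'weight_decay': 1e-4},
+ #                              {'params': prelu_params, 'weight_decay': 0.0}], lr=0.1)
+ 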
+ class Softsign(Module):
+     r"""Applies the element-wise Softsign function.
+ 
+     .. math::
+         \text{SoftSign}(x) = \frac{x}{1 + |x|}
+ 
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+ 
+     .. image:: ../scripts/activation_images/Softsign.png
+ 
+     Examples::
+ 
+         >>> m = nn.Softsign()
+         >>> input = torch.randn(2)
+         >>> output = m(input)
+     """
+ 
+     def forward(self, input: Tensor) -> Tensor:
+         return F.softsign(input)
+ 
+ 
+ class Tanhshrink(Module):
+     r"""Applies the element-wise Tanhshrink function.
+ 
+     .. math::
+         \text{Tanhshrink}(x) = x - \tanh(x)
+ 
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+ 
+     .. image:: ../scripts/activation_images/Tanhshrink.png
+ 
+     Examples::
+ 
+         >>> m = nn.Tanhshrink()
+         >>> input = torch.randn(2)
+         >>> output = m(input)
+     """
+ 
+     def forward(self, input: Tensor) -> Tensor:
+         return F.tanhshrink(input)
+ 
+ 
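+ # Tanhshrink is nearly zero around the origin (x - tanh(x) is approximately
+ # x^3/3 for small x) and approaches x - 1 for large positive x. A quick check:
+ #
+ #     m = nn.Tanhshrink()
+ #     m(torch.tensor([0.1, 5.0]))   # ~ tensor([0.0003, 4.0001])
+ 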
+ class Softmin(Module):
+     r"""Applies the Softmin function to an n-dimensional input Tensor.
+ 
+     Rescales the input so that the elements of the n-dimensional output Tensor
+     lie in the range `[0, 1]` and sum to 1.
+ 
+     Softmin is defined as:
+ 
+     .. math::
+         \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
+ 
+     Shape:
+         - Input: :math:`(*)` where :math:`*` means any number of additional
+           dimensions
+         - Output: :math:`(*)`, same shape as the input
+ 
+     Args:
+         dim (int): A dimension along which Softmin will be computed (so every slice
+             along dim will sum to 1).
+ 
+     Returns:
+         a Tensor of the same dimension and shape as the input, with
+         values in the range [0, 1]
+ 
+     Examples::
+ 
+         >>> m = nn.Softmin(dim=1)
+         >>> input = torch.randn(2, 3)
+         >>> output = m(input)
+     """
+ 
+     __constants__ = ['dim']
+     dim: Optional[int]
+ 
+     def __init__(self, dim: Optional[int] = None) -> None:
+         super().__init__()
+         self.dim = dim
+ 
+     def __setstate__(self, state):
+         super().__setstate__(state)
+         if not hasattr(self, 'dim'):
+             self.dim = None
+ 
+     def forward(self, input: Tensor) -> Tensor:
+         return F.softmin(input, self.dim, _stacklevel=5)
+ 
+     def extra_repr(self):
+         return f'dim={self.dim}'
+ 
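+ # Softmin(x) is exactly Softmax(-x), which follows directly from the definition
+ # above. A quick sanity check:
+ #
+ #     x = torch.randn(2, 3)
+ #     torch.allclose(nn.Softmin(dim=1)(x), nn.Softmax(dim=1)(-x))   # True
+ 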
+ class Softmax(Module):
+     r"""Applies the Softmax function to an n-dimensional input Tensor.
+ 
+     Rescales the input so that the elements of the n-dimensional output Tensor
+     lie in the range [0,1] and sum to 1.
+ 
+     Softmax is defined as:
+ 
+     .. math::
+         \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+ 
+     When the input Tensor is a sparse tensor, the unspecified
+     values are treated as ``-inf``.
+ 
+     Shape:
+         - Input: :math:`(*)` where :math:`*` means any number of additional
+           dimensions
+         - Output: :math:`(*)`, same shape as the input
+ 
+     Returns:
+         a Tensor of the same dimension and shape as the input with
+         values in the range [0, 1]
+ 
+     Args:
+         dim (int): A dimension along which Softmax will be computed (so every slice
+             along dim will sum to 1).
+ 
+     .. note::
+         This module doesn't work directly with NLLLoss,
+         which expects log-probabilities as input rather than the probabilities
+         that Softmax produces.
+         Use `LogSoftmax` instead (it's faster and has better numerical properties).
+ 
+     Examples::
+ 
+         >>> m = nn.Softmax(dim=1)
+         >>> input = torch.randn(2, 3)
+         >>> output = m(input)
+ 
+     """
+ 
+     __constants__ = ['dim']
+     dim: Optional[int]
+ 
+     def __init__(self, dim: Optional[int] = None) -> None:
+         super().__init__()
+         self.dim = dim
+ 
+     def __setstate__(self, state):
+         super().__setstate__(state)
+         if not hasattr(self, 'dim'):
+             self.dim = None
+ 
+     def forward(self, input: Tensor) -> Tensor:
+         return F.softmax(input, self.dim, _stacklevel=5)
+ 
+     def extra_repr(self) -> str:
+         return f'dim={self.dim}'
+ 
+ 
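+ # The note above in practice: NLLLoss consumes log-probabilities, so pair it
+ # with LogSoftmax; the combination matches CrossEntropyLoss on raw logits.
+ #
+ #     logits = torch.randn(4, 10)
+ #     target = torch.randint(10, (4,))
+ #     a = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
+ #     b = nn.CrossEntropyLoss()(logits, target)
+ #     torch.allclose(a, b)   # True
+ 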
+ class Softmax2d(Module):
+     r"""Applies SoftMax over features to each spatial location.
+ 
+     When given an image of ``Channels x Height x Width``, it will
+     apply `Softmax` to each location :math:`(Channels, h_i, w_j)`.
+ 
+     Shape:
+         - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+         - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
+ 
+     Returns:
+         a Tensor of the same dimension and shape as the input with
+         values in the range [0, 1]
+ 
+     Examples::
+ 
+         >>> m = nn.Softmax2d()
+         >>> # you softmax over the 2nd dimension
+         >>> input = torch.randn(2, 3, 12, 13)
+         >>> output = m(input)
+     """
+ 
+     def forward(self, input: Tensor) -> Tensor:
+         if input.dim() not in (3, 4):
+             raise ValueError(
+                 f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead"
+             )
+         # softmax over the channel dimension (third from the end)
+         return F.softmax(input, -3, _stacklevel=5)
+ 
+ 
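+ # Per the shape contract above, dim=-3 is the channel dim for both 3D and 4D
+ # inputs, so the channel probabilities at every spatial location sum to 1:
+ #
+ #     out = nn.Softmax2d()(torch.randn(2, 3, 12, 13))
+ #     torch.allclose(out.sum(dim=1), torch.ones(2, 12, 13))   # True
+ 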
+ class LogSoftmax(Module):
+     r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
+ 
+     The LogSoftmax formulation can be simplified as:
+ 
+     .. math::
+         \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i)}{\sum_j \exp(x_j)}\right)
+ 
+     Shape:
+         - Input: :math:`(*)` where :math:`*` means any number of additional
+           dimensions
+         - Output: :math:`(*)`, same shape as the input
+ 
+     Args:
+         dim (int): A dimension along which LogSoftmax will be computed.
+ 
+     Returns:
+         a Tensor of the same dimension and shape as the input with
+         values in the range [-inf, 0)
+ 
+     Examples::
+ 
+         >>> m = nn.LogSoftmax(dim=1)
+         >>> input = torch.randn(2, 3)
+         >>> output = m(input)
+     """
+ 
+     __constants__ = ['dim']
+     dim: Optional[int]
+ 
+     def __init__(self, dim: Optional[int] = None) -> None:
+         super().__init__()
+         self.dim = dim
+ 
+     def __setstate__(self, state):
+         super().__setstate__(state)
+         if not hasattr(self, 'dim'):
+             self.dim = None
+ 
+     def forward(self, input: Tensor) -> Tensor:
+         return F.log_softmax(input, self.dim, _stacklevel=5)
+ 
+     def extra_repr(self):
+         return f'dim={self.dim}'
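+ 
+ # ``LogSoftmax`` is preferred over composing ``log`` with ``Softmax`` because it
+ # avoids underflow when exponentiating large logits. A quick illustration:
+ #
+ #     x = torch.tensor([[1000.0, 0.0]])
+ #     torch.log(nn.Softmax(dim=1)(x))   # tensor([[0., -inf]]) -- underflow
+ #     nn.LogSoftmax(dim=1)(x)           # tensor([[0., -1000.]]) -- stable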