# Spaces: Runtime error
# File size: 6,401 Bytes
# 8646273
import torch
import numpy as np
from scipy import signal
from scipy.signal import butter, lfilter, detrend
# Make bandpass filter
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Both cut-off frequencies are divided by the Nyquist frequency
    (``0.5 * fs``) before being passed to ``scipy.signal.butter``.

    Args:
        lowcut: Lower cut-off frequency, in the same units as ``fs``.
        highcut: Upper cut-off frequency, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Filter order handed to ``butter`` (default 5).

    Returns:
        Tuple ``(b, a)`` of numerator and denominator coefficients.
    """
    nyquist = 0.5 * fs
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band_edges, btype="band")
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter ``data`` with a Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.lfilter`` using its default axis.

    Args:
        data: Array-like signal to filter.
        lowcut: Lower cut-off frequency, in the same units as ``fs``.
        highcut: Upper cut-off frequency, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Filter order forwarded to ``butter_bandpass`` (default 5).

    Returns:
        The filtered signal as returned by ``lfilter``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)
def rotate_waveform(waveform, angle):
    """Multiply the spectrum of ``waveform`` by a constant phase factor.

    Args:
        waveform: Array-like signal; transformed with ``np.fft.fft``.
        angle: Phase angle used as ``np.exp(1j * angle)``.

    Returns:
        Complex array: the inverse FFT of the phase-multiplied spectrum.
    """
    # Forward transform, apply the same complex factor to every bin, go back.
    spectrum = np.fft.fft(waveform)
    phase_factor = np.exp(1j * angle)
    return np.fft.ifft(spectrum * phase_factor)
def augment(sample):
    """Crop, preprocess and (for training samples) augment one waveform.

    A window of ``crop_length`` samples is cut around a randomly chosen
    phase pick, detrended, band-pass filtered, tapered and normalised by its
    absolute maximum. Samples whose ``meta["split"]`` is ``"train"``
    additionally get a random "rotation" and random channel drop-out.

    Args:
        sample: Mapping with the keys ``"waveform.npy"`` (2-D array indexed
            as ``[channel, time]``) and ``"meta.json"`` (mapping providing
            ``"split"``, ``"trace_p_arrival_sample"`` and
            ``"trace_s_arrival_sample"``; the arrival entries may be
            ``None``).

    Returns:
        Tuple ``(waveform, target_P, target_S)``: the processed crop (padded
        with zero channels up to three when fewer are present) and the P/S
        pick positions as fractions of ``crop_length``. A target is ``0``
        when the pick is missing, falls outside the open interval (0, 1),
        or the crop is invalid.
    """
    # SET PARAMETERS:
    crop_length = 6000
    padding = 120

    waveform = sample["waveform.npy"]
    meta = sample["meta.json"]
    # Everything that is not the training split is handled deterministically.
    test = meta["split"] != "train"

    target_sample_P = meta["trace_p_arrival_sample"]
    target_sample_S = meta["trace_s_arrival_sample"]
    if target_sample_P is None:
        target_sample_P = 0
    if target_sample_S is None:
        target_sample_S = 0

    # Randomly select a phase to start the crop
    current_phases = [x for x in (target_sample_P, target_sample_S) if x > 0]
    if current_phases:
        phase_selector = np.random.randint(0, len(current_phases))
        first_phase = current_phases[phase_selector]
    else:
        # No usable pick: randint(0, 0) would raise, so use 0, which makes
        # the branches below pick the default start (``padding``).
        first_phase = 0

    # Shuffle
    if first_phase - (crop_length - padding) > padding:
        if test:
            start_indx = int(first_phase - 2 * padding)
        else:
            start_indx = int(
                first_phase
                - torch.randint(
                    low=padding, high=(crop_length - padding), size=(1,)
                )
            )
    elif int(first_phase - padding) > 0:
        if test:
            start_indx = int(first_phase - padding)
        else:
            start_indx = int(
                first_phase
                - torch.randint(
                    low=0, high=(int(first_phase - padding)), size=(1,)
                )
            )
    else:
        start_indx = padding

    end_indx = start_indx + crop_length
    if (waveform.shape[-1] - end_indx) < 0:
        # Shift the window left so it ends at the last sample; clamp at 0 so
        # a waveform shorter than crop_length does not produce a negative
        # start (which would wrap around when slicing).
        start_indx = max(start_indx + waveform.shape[-1] - end_indx, 0)
        end_indx = start_indx + crop_length

    # Update target
    new_target_P = target_sample_P - start_indx
    new_target_S = target_sample_S - start_indx

    # Cut
    waveform_cropped = waveform[:, start_indx:end_indx]
    missing = crop_length - waveform_cropped.shape[-1]
    if missing > 0:
        # Short record: zero-pad on the right so every crop has crop_length
        # samples.
        waveform_cropped = np.pad(
            waveform_cropped,
            ((0, 0), (0, missing)),
            mode="constant",
            constant_values=0,
        )

    # Preprocess
    waveform_cropped = detrend(waveform_cropped)
    waveform_cropped = butter_bandpass_filter(
        waveform_cropped, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform_cropped[-1].shape[0], alpha=0.1)
    waveform_cropped = waveform_cropped * window
    waveform_cropped = detrend(waveform_cropped)

    if np.isnan(waveform_cropped).any():
        waveform_cropped = np.zeros(shape=waveform_cropped.shape)
        new_target_P = 0
        new_target_S = 0
    if np.sum(waveform_cropped) == 0:
        new_target_P = 0
        new_target_S = 0

    # Normalize data
    max_val = np.max(np.abs(waveform_cropped))
    if max_val > 0:
        waveform_cropped_norm = waveform_cropped / max_val
    else:
        # All-zero crop (e.g. after the NaN guard): dividing by zero would
        # turn it back into NaN, so leave it as zeros.
        waveform_cropped_norm = waveform_cropped

    # Added Z component only
    n_channels = len(waveform_cropped_norm)
    if n_channels < 3:
        zeros = np.zeros((3, waveform_cropped_norm.shape[-1]))
        # Existing channels occupy the leading rows; the rest stay zero.
        zeros[:n_channels] = waveform_cropped_norm
        waveform_cropped_norm = zeros

    if not test:
        ##### Rotate waveform #####
        probability = torch.randint(0, 2, size=(1,)).item()
        # NOTE(review): the range looks like degrees, but rotate_waveform
        # feeds the value straight into np.exp(1j * angle) -- confirm intent.
        angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
        if probability == 1:
            waveform_cropped_norm = rotate_waveform(
                waveform_cropped_norm, angle
            ).real

        #### Channel DropOUT #####
        probability = torch.randint(0, 2, size=(1,)).item()
        # randint(1, 3) yields 1 or 2, so channel 0 is never dropped.
        channel = torch.randint(1, 3, size=(1,)).item()
        if probability == 1:
            waveform_cropped_norm[channel, :] = 1e-6

    # Normalize target
    new_target_P = new_target_P / crop_length
    new_target_S = new_target_S / crop_length
    if (new_target_P <= 0) or (new_target_P >= 1) or (np.isnan(new_target_P)):
        new_target_P = 0
    if (new_target_S <= 0) or (new_target_S >= 1) or (np.isnan(new_target_S)):
        new_target_S = 0

    return waveform_cropped_norm, new_target_P, new_target_S
def collation_fn(sample):
    """Collate ``(waveform, target_P, target_S)`` tuples into float tensors.

    Args:
        sample: Sequence of 3-item records; item 0 is the waveform, item 1
            the P target and item 2 the S target.

    Returns:
        Tuple of three ``torch.float`` tensors (waveforms, P targets,
        S targets), each stacked along a new leading axis.
    """
    # Stack every field first, then convert each stacked array to a tensor.
    stacked_fields = [
        np.stack([record[field] for record in sample]) for field in range(3)
    ]
    return tuple(
        torch.tensor(field_values, dtype=torch.float)
        for field_values in stacked_fields
    )
def my_split_by_node(urls):
    """Return the share of ``urls`` belonging to the current process.

    Every ``world_size``-th URL is taken, starting at this process's rank.
    When ``torch.distributed`` is unavailable or no process group has been
    initialised (a plain single-process run), all URLs are returned instead
    of raising.

    Args:
        urls: Iterable of URLs (or any items) to distribute.

    Returns:
        List with the items assigned to this process.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        node_id, node_count = dist.get_rank(), dist.get_world_size()
    else:
        # No process group: behave like rank 0 of a world of size 1.
        node_id, node_count = 0, 1
    return list(urls)[node_id::node_count]
def prepare_waveform(waveform):
    """Preprocess a waveform and replicate it into a batch of 128 copies.

    The waveform is zero-padded or truncated to ``crop_length`` samples,
    detrended, band-pass filtered, tapered with a Tukey window, detrended
    again, scaled by its absolute maximum and padded with zero channels up
    to three.

    Args:
        waveform: 2-D array indexed as ``[channel, time]`` with at most
            three channels.

    Returns:
        ``torch.float`` tensor holding 128 identical copies of the processed
        waveform along a new leading axis.

    Raises:
        AssertionError: If there are more than three channels, the processed
            waveform contains NaN, or its sum is zero.
    """
    # SET PARAMETERS:
    crop_length = 6000
    n_copies = 128  # presumably the batch size the model expects -- confirm

    assert waveform.shape[0] <= 3, "Waveform has more than 3 channels"

    n_samples = waveform.shape[-1]
    if n_samples < crop_length:
        waveform = np.pad(
            waveform,
            ((0, 0), (0, crop_length - n_samples)),
            mode="constant",
            constant_values=0,
        )
    elif n_samples > crop_length:
        waveform = waveform[:, :crop_length]

    # Preprocess
    waveform = detrend(waveform)
    waveform = butter_bandpass_filter(
        waveform, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform[-1].shape[0], alpha=0.1)
    waveform = waveform * window
    waveform = detrend(waveform)

    assert not np.isnan(waveform).any(), "Nan in waveform"
    assert np.sum(waveform) != 0, "Sum of waveform sample is zero"

    # Normalize data
    max_val = np.max(np.abs(waveform))
    waveform = waveform / max_val

    # Added Z component only
    n_channels = len(waveform)
    if n_channels < 3:
        zeros = np.zeros((3, waveform.shape[-1]))
        # Existing channels occupy the leading rows; the rest stay zero.
        zeros[:n_channels] = waveform
        waveform = zeros

    # Convert once and repeat on the tensor side; building a tensor from a
    # Python list of 128 ndarrays is a very slow path in PyTorch.
    single = torch.tensor(waveform, dtype=torch.float)
    return single.unsqueeze(0).repeat(n_copies, 1, 1)