|
|
|
|
|
|
|
|
|
// Half-precision (`half`, `half8`, ...) arithmetic and conversions.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// Extended async copies; the kernel in this file calls async_work_group_copy_2D2D.
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable
|
|
|
__kernel void Convolution3x3( |
|
const __global half *in_param, |
|
const __global half *out, |
|
const __global half *w, |
|
int IW, |
|
int IH, |
|
int IC, |
|
int OW, |
|
int OH, |
|
int OC, |
|
int KX, |
|
int KY, |
|
int stride_x, |
|
int stride_y, |
|
int pad_x, |
|
int pad_y, |
|
int dilation_x, |
|
int dilation_y) |
|
{ |
|
__local half in_local[8 * 1024]; |
|
__local half out_local[8 * 1024]; |
|
__local half w_local[8 * 1024]; |
|
|
|
const int sizePlane = IW * IH; |
|
event_t e1 = async_work_group_copy_2D2D( |
|
in_local, |
|
in_param + get_group_id(0) * stride_y * IW, |
|
3 * IW, |
|
IC, |
|
IW * IH - 3 * IW, |
|
0, |
|
0); |
|
wait_group_events(1, &e1); |
|
|
|
const int sizeWeight = IC * 3 * 3; |
|
e1 = async_work_group_copy(w_local, w + get_group_id(1) * sizeWeight, sizeWeight, 0); |
|
wait_group_events(1, &e1); |
|
|
|
int oh = get_global_id(0); |
|
int oc = get_global_id(1); |
|
|
|
__local half *in = (__local half *)in_local + 1; |
|
|
|
int stride; |
|
int write_output = 0; |
|
__local half *src; |
|
|
|
if ((stride_x == 1) && (stride_y == 1)) { |
|
stride = OW / 8; |
|
write_output = 1; |
|
} |
|
if ((stride_x == 2) && (stride_y == 2)) { |
|
stride = OW / 4; |
|
write_output = 2; |
|
} |
|
|
|
for (int ow = 0; ow < stride; ow++) { |
|
float8 val = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; |
|
for (int ic = 0; ic < IC; ++ic) { |
|
src = (__local half *)((__local half8 *)(in + ic * IW * 3) + ow); |
|
__local half *k = (__local half *)(w_local + ic * 3 * 3); |
|
|
|
half8 aux_in00 = *((__local half8 *)src - 1); |
|
half8 aux_in01 = *((__local half8 *)src + 0); |
|
half8 aux_in02 = *((__local half8 *)src + 1); |
|
half8 aux_in10 = *((__local half8 *)(src + IW) - 1); |
|
half8 aux_in11 = *((__local half8 *)(src + IW) + 0); |
|
half8 aux_in12 = *((__local half8 *)(src + IW) + 1); |
|
half8 aux_in20 = *((__local half8 *)(src + IW * 2) - 1); |
|
half8 aux_in21 = *((__local half8 *)(src + IW * 2) + 0); |
|
half8 aux_in22 = *((__local half8 *)(src + IW * 2) + 1); |
|
|
|
short8 in00 = *((short8 *)&aux_in00); |
|
short8 in01 = *((short8 *)&aux_in01); |
|
short8 in02 = *((short8 *)&aux_in02); |
|
short8 in10 = *((short8 *)&aux_in10); |
|
short8 in11 = *((short8 *)&aux_in11); |
|
short8 in12 = *((short8 *)&aux_in12); |
|
short8 in20 = *((short8 *)&aux_in20); |
|
short8 in21 = *((short8 *)&aux_in21); |
|
short8 in22 = *((short8 *)&aux_in22); |
|
|
|
short8 aux_aux00 = __builtin_shave_cmu_alignvec_rri_short8(in00, in01, 14); |
|
short8 aux_aux01 = in01; |
|
short8 aux_aux02 = __builtin_shave_cmu_alignvec_rri_short8(in01, in02, 2); |
|
short8 aux_aux10 = __builtin_shave_cmu_alignvec_rri_short8(in10, in11, 14); |
|
short8 aux_aux11 = in11; |
|
short8 aux_aux12 = __builtin_shave_cmu_alignvec_rri_short8(in11, in12, 2); |
|
short8 aux_aux20 = __builtin_shave_cmu_alignvec_rri_short8(in20, in21, 14); |
|
short8 aux_aux21 = in21; |
|
short8 aux_aux22 = __builtin_shave_cmu_alignvec_rri_short8(in21, in22, 2); |
|
|
|
half8 aux00 = *((half8 *)&aux_aux00); |
|
half8 aux01 = *((half8 *)&aux_aux01); |
|
half8 aux02 = *((half8 *)&aux_aux02); |
|
half8 aux10 = *((half8 *)&aux_aux10); |
|
half8 aux11 = *((half8 *)&aux_aux11); |
|
half8 aux12 = *((half8 *)&aux_aux12); |
|
half8 aux20 = *((half8 *)&aux_aux20); |
|
half8 aux21 = *((half8 *)&aux_aux21); |
|
half8 aux22 = *((half8 *)&aux_aux22); |
|
|
|
half8 w00 = (half8)(*(k + 0)); |
|
half8 w01 = (half8)(*(k + 1)); |
|
half8 w02 = (half8)(*(k + 2)); |
|
half8 w10 = (half8)(*(k + 3)); |
|
half8 w11 = (half8)(*(k + 4)); |
|
half8 w12 = (half8)(*(k + 5)); |
|
half8 w20 = (half8)(*(k + 6)); |
|
half8 w21 = (half8)(*(k + 7)); |
|
half8 w22 = (half8)(*(k + 8)); |
|
|
|
val += convert_float8(aux00) * convert_float8(w00); |
|
val += convert_float8(aux01) * convert_float8(w01); |
|
val += convert_float8(aux02) * convert_float8(w02); |
|
val += convert_float8(aux10) * convert_float8(w10); |
|
val += convert_float8(aux11) * convert_float8(w11); |
|
val += convert_float8(aux12) * convert_float8(w12); |
|
val += convert_float8(aux20) * convert_float8(w20); |
|
val += convert_float8(aux21) * convert_float8(w21); |
|
val += convert_float8(aux22) * convert_float8(w22); |
|
} |
|
if (write_output == 2) *((__local half4 *)(out_local) + ow) = convert_half4(val.s0246); |
|
if (write_output == 1) *((__local half8 *)(out_local) + ow) = convert_half8(val); |
|
} |
|
|
|
for (int ow = OW & ~(0x7); ow < OW; ow++) { |
|
float val = 0.0f; |
|
for (int ic = 0; ic < IC; ++ic) { |
|
for (int ky = 0; ky < 3; ++ky) { |
|
for (int kx = 0; kx < 3; ++kx) { |
|
int iw = ow * stride_x - pad_x + kx * dilation_x; |
|
int ih = oh * stride_y - pad_y + ky * dilation_y; |
|
|
|
val += convert_float(in[ic * IW * 3 + (ky * dilation_y) * IW + iw]) |
|
* convert_float(w_local[ic * 3 * 3 + ky * 3 + kx]); |
|
} |
|
} |
|
} |
|
out_local[ow] = convert_half(val); |
|
} |
|
|
|
barrier(CLK_LOCAL_MEM_FENCE); |
|
|
|
event_t e2 = async_work_group_copy( |
|
out + get_group_id(1) * OW * OH + get_group_id(0) * OW, |
|
out_local, |
|
OW, |
|
0); |
|
wait_group_events(1, &e2); |
|
} |
|
|