import numpy as np
from numba import njit

@njit
def zcr_rate(wav_in, step=240, segment=960):
    """Compute the zero-crossing rate of a signal over sliding windows.

    Parameters
    ----------
    wav_in : array
        Input samples. Presumably a 1-D waveform; the differenced signal
        is indexed as a flat sequence below (TODO confirm with callers).
    step : int
        Hop, in samples, between the starts of consecutive windows.
        Assumed to be positive.
    segment : int
        Window length, measured in samples of the differenced signal.

    Returns
    -------
    numpy.ndarray
        float64 array with one value per window: the mean of the clipped
        sign-change indicator over that window (the fraction of adjacent
        sample pairs whose sign differs). Empty when
        ``len(wav_in) - 1 < segment + step``.
    """
    # Shift by a tiny epsilon so exact zeros get sign +1 rather than 0.
    # Otherwise the edges of a run of exact zeros would count as sign changes.
    sign_wav = np.sign(wav_in + 1e-8)

    # Adjacent sign values lie in {-1, 0, 1}, so |diff| is 0, 1 or 2.
    # Clip to 1 so that every sign change counts exactly once.
    absdiff = np.abs(np.diff(sign_wav))
    for i in range(len(absdiff)):
        if absdiff[i] > 1:
            absdiff[i] = 1

    # NOTE(review): this produces one window fewer than the usual
    # 1 + (n - segment) // step framing (the last full window is dropped).
    # It is kept unchanged because callers may rely on this output length;
    # confirm whether that is intended.
    # Clamp at 0: for inputs shorter than one window the difference is
    # negative, and np.zeros raises ValueError on a negative size.
    n_windows = max(0, (len(absdiff) - segment) // step)

    zcrate = np.zeros(n_windows)
    for i in range(n_windows):
        start = i * step
        zcrate[i] = np.mean(absdiff[start:start + segment])

    return zcrate