Deep Dive into Automatic Speech Recognition: Benchmarking Whisper JAX and PyTorch Implementations Across Platforms
In the world of Automatic Speech Recognition (ASR), speed and accuracy are of great importance. The size of the data and models has been growing significantly lately, making it hard to be efficient. Nevertheless, the race is just starting, and we see new developments every week. In this article, we focus on Whisper JAX, a recent implementation of Whisper using a different backend framework that appears to run 70 times faster than OpenAI’s PyTorch implementation. We tested both CPU and GPU implementations and measured accuracy and execution time. We also defined experiments for small and large-size models while parametrizing batch size and data types to see if we could improve performance further.
As we saw in our previous article, Whisper is a versatile speech recognition model that excels in several speech-processing tasks. It can perform multilingual speech recognition, translation, and even voice activity detection. It uses a Transformer sequence-to-sequence architecture to predict words and tasks jointly. Whisper works as a meta-model for speech-processing tasks. One of the downsides of Whisper is its efficiency; it is often found to be fairly slow compared to other state-of-the-art models.
In the following sections, we go through the details of what changed with this new approach. We compare Whisper and Whisper JAX, highlight the main differences between PyTorch and JAX, and develop a pipeline to evaluate the speed and accuracy of both implementations.
This article belongs to “Large Language Models Chronicles: Navigating the NLP Frontier”, a new weekly series of articles that will explore how to leverage the power of large models for various NLP tasks. By diving into these cutting-edge technologies, we aim to empower developers, researchers, and enthusiasts to harness the potential of NLP and unlock new possibilities.
Articles published so far:
- Summarizing the latest Spotify releases with ChatGPT
- Master Semantic Search at Scale: Index Millions of Documents with Lightning-Fast Inference Times using FAISS and Sentence Transformers
- Unlock the Power of Audio Data: Advanced Transcription and Diarization with Whisper, WhisperX, and PyAnnotate
As always, the code is available on my GitHub.
The Machine Learning community extensively uses powerful libraries like PyTorch and JAX. While they share some similarities, their inner workings are quite different. Let’s understand the main differences.
The AI Research Lab at Meta developed PyTorch and actively maintains it today. It is an open-source library based on the Torch library. Researchers use PyTorch extensively because of its dynamic computation graph, intuitive interface, and robust debugging capabilities. The fact that it uses dynamic graphs gives it greater flexibility in building new models and simplifies modifying such models during runtime. It is close to Python, and specifically to the NumPy API. The main difference is that we are not working with arrays but with tensors, which can run on GPUs and support automatic differentiation.
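As a quick illustration of that NumPy-like interface, here is a minimal sketch (ours, not from the original article) of a tensor placed on the GPU when one is available, with gradients computed automatically:

import torch

# Tensors mirror NumPy arrays but can live on an accelerator
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = torch.randn(3, device=device, requires_grad=True)

# Autograd records the operations and computes dy/dx = 2x on backward()
y = (x ** 2).sum()
y.backward()
print(x.grad)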
JAX is a high-performance library developed by Google. Conversely to PyTorch, JAX combines the advantages of static and dynamic computation graphs. It does this by its just-in-time compilation characteristic, which provides flexibility and efficiency. We will consider JAX being a stack of interpreters that progressively rewrite your program. It will definitely offloads the precise computation to XLA — the Accelerated Linear Algebra compiler, additionally designed and developed by Google, to speed up Machine Studying computations.
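To see this mechanism in action, here is a minimal sketch (again ours) of just-in-time compilation: the first call traces the function and compiles it with XLA, and later calls reuse the compiled program. The same mechanism underlies Whisper JAX’s claim that a second transcription should be faster.

import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # Traced once per input shape/dtype, then compiled by XLA
    return 0.5 * jnp.sum(x ** 2)

x = jnp.arange(1_000_000, dtype=jnp.float32)
scaled_sum(x).block_until_ready()  # first call: trace + XLA compilation
scaled_sum(x).block_until_ready()  # later calls: run the cached program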
Let’s start by building a class to handle audio transcriptions using Whisper with PyTorch (OpenAI’s implementation) or Whisper with JAX. Our class is a wrapper for the models and an interface to easily set up experiments. We want to perform several experiments, including specifying the device, the model type, and additional hyperparameters for Whisper JAX. Note that we used a singleton pattern to ensure that, as we run several experiments, we don’t end up with multiple instances of the model consuming our memory.
from typing import Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import torch
import whisper
from whisper_jax import FlaxWhisperPipline


class Transcription:
    """
    A class to handle audio transcriptions using either the Whisper or Whisper JAX model.

    Attributes:
        audio_file_path (str): Path to the audio file to transcribe.
        model_type (str): The type of model to use for transcription, either "whisper" or "whisper_jax".
        device (str): The device to use for inference (e.g., "cpu" or "cuda").
        model_name (str): The specific model to use (e.g., "base", "medium", "large", or "large-v2").
        dtype (Optional[str]): The data type to use for Whisper JAX, either "bfloat16" or "float32".
        batch_size (Optional[int]): The batch size to use for Whisper JAX.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        audio_file_path: str,
        model_type: str = "whisper",
        device: str = "cpu",
        model_name: str = "base",
        dtype: Optional[str] = None,
        batch_size: Optional[int] = None,
    ):
        self.audio_file_path = audio_file_path
        self.device = device
        self.model_type = model_type
        self.model_name = model_name
        self.dtype = dtype
        self.batch_size = batch_size
        self.pipeline = None
The set_pipeline method sets up the pipeline for the specified model type. Depending on the value of the model_type attribute, the method initializes the pipeline either by instantiating the FlaxWhisperPipline class for Whisper JAX or by calling the whisper.load_model() function for the PyTorch implementation of Whisper.
    def set_pipeline(self) -> None:
        """
        Set up the pipeline for the specified model type.

        Returns:
            None
        """
        if self.model_type == "whisper_jax":
            pipeline_kwargs = {}
            if self.dtype:
                pipeline_kwargs["dtype"] = getattr(jnp, self.dtype)
            if self.batch_size:
                pipeline_kwargs["batch_size"] = self.batch_size
            self.pipeline = FlaxWhisperPipline(
                f"openai/whisper-{self.model_name}", **pipeline_kwargs
            )
        elif self.model_type == "whisper":
            self.pipeline = whisper.load_model(
                self.model_name,
                torch.device("cuda:0") if self.device == "gpu" else self.device,
            )
        else:
            raise ValueError(f"Invalid model type: {self.model_type}")
The run_pipeline method transcribes the audio file and returns the results as a list of dictionaries containing the transcribed text and timestamps. In the case of Whisper JAX, it considers optional parameters like data type and batch size, if provided. Notice that you can set return_timestamps to False if you are only interested in getting the transcription. The model output is different if we run the transcription process with the PyTorch implementation. Thus, we must create a new object that aligns both return objects.
    def run_pipeline(self) -> List[Dict[str, Union[Tuple[float, float], str]]]:
        """
        Run the transcription pipeline.

        Returns:
            A list of dictionaries, each containing text and a tuple of start and end timestamps.
        """
        if self.pipeline is None:
            raise ValueError("Pipeline not initialized. Call set_pipeline() first.")
        if self.model_type == "whisper_jax":
            outputs = self.pipeline(
                self.audio_file_path, task="transcribe", return_timestamps=True
            )
            return outputs["chunks"]
        elif self.model_type == "whisper":
            result = self.pipeline.transcribe(self.audio_file_path)
            formatted_result = [
                {
                    "timestamp": (segment["start"], segment["end"]),
                    "text": segment["text"],
                }
                for segment in result["segments"]
            ]
            return formatted_result
        else:
            raise ValueError(f"Invalid model type: {self.model_type}")
Finally, the transcribe_multiple() method enables the transcription of multiple audio files. It takes a list of audio file paths and returns a list of transcriptions for each audio file, where each transcription is a list of dictionaries containing text and a tuple of start and end timestamps.
    def transcribe_multiple(
        self, audio_file_paths: List[str]
    ) -> List[List[Dict[str, Union[Tuple[float, float], str]]]]:
        """
        Transcribe multiple audio files using the specified model type.

        Args:
            audio_file_paths (List[str]): A list of audio file paths to transcribe.

        Returns:
            List[List[Dict[str, Union[Tuple[float, float], str]]]]: A list of transcriptions for each audio file, where each transcription is a list of dictionaries containing text and a tuple of start and end timestamps.
        """
        transcriptions = []
        for audio_file_path in audio_file_paths:
            self.audio_file_path = audio_file_path
            self.set_pipeline()
            transcription = self.run_pipeline()
            transcriptions.append(transcription)
        return transcriptions
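With the class in place, a minimal usage sketch looks like the following. The file path and hyperparameter values here are illustrative placeholders, not the exact settings from our experiments:

# Illustrative example: transcribe one file with Whisper JAX,
# enabling half-precision and batching (both optional).
# Note: `device` only affects the PyTorch path; JAX picks its
# accelerator automatically.
transcriber = Transcription(
    audio_file_path="audio/long_clip.mp3",  # placeholder path
    model_type="whisper_jax",
    model_name="large-v2",
    dtype="bfloat16",
    batch_size=16,
)
transcriber.set_pipeline()
for chunk in transcriber.run_pipeline():
    print(chunk["timestamp"], chunk["text"])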
Experimental Setup
We used a long audio clip of more than 30 minutes to evaluate the performance of the Whisper variants with the PyTorch and JAX implementations. The researchers who developed Whisper JAX claim that the difference is more significant when transcribing long audio files.
Our experimental hardware setup consists of the following key components. For the CPU, we have an x86_64 architecture with a total of 112 cores, powered by an Intel(R) Xeon(R) Gold 6258R CPU running at 2.70GHz. For the GPU, we used an NVIDIA Quadro RTX 8000 with 48 GB of VRAM.
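The full benchmark code is on GitHub; as a reference, a simple wall-clock helper along these lines is enough to reproduce the timing measurements (the function name and repeat logic below are our own sketch, not the exact benchmark code):

import time

def time_transcription(transcriber: Transcription, runs: int = 2) -> list:
    """Measure the wall-clock duration (in seconds) of each transcription run."""
    transcriber.set_pipeline()
    durations = []
    for _ in range(runs):
        start = time.perf_counter()
        transcriber.run_pipeline()
        durations.append(time.perf_counter() - start)
    # For Whisper JAX, durations[0] includes JIT compilation, so comparing
    # it with durations[1] tests the "faster second transcription" claim.
    return durations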
Results and Discussion
In this section, we discuss the results obtained from the experiments comparing the performance of the Whisper JAX and PyTorch implementations. Our results provide insights into the speed and efficiency of the two implementations on both GPU and CPU platforms.
Our first experiment involved running a long audio clip (over 30 minutes) using the GPU and the larger Whisper model (large-v2), which requires roughly 10GB of VRAM. Contrary to the claim made by the authors of Whisper JAX, our results indicate that the JAX implementation is slower than the PyTorch version. Even with the incorporation of half-precision and batching, we could not surpass the performance of the PyTorch implementation using Whisper JAX. Whisper JAX took almost twice as long as the PyTorch implementation to perform the same transcription. We also observed an unusually long transcription time when both half-precision and batching were employed.
On the other hand, when evaluating CPU performance, our results show that Whisper JAX outperforms the PyTorch implementation, running roughly two times faster. We observed this pattern for both the base and large model versions.
Regarding the claim made by the authors of Whisper JAX that the second transcription should be much faster, our experiments did not provide supporting evidence. The difference in speed between the first and second transcriptions was not significant. Moreover, we found that the pattern was similar between the Whisper and Whisper JAX implementations.
In this article, we presented a comprehensive analysis of the Whisper JAX implementation, comparing its performance to the original PyTorch implementation of Whisper. Our experiments aimed to evaluate the claimed 70x speed improvement using a variety of setups, including different hardware and hyperparameters for the Whisper JAX model.
The results showed that Whisper JAX outperformed the PyTorch implementation on CPU platforms, with a speedup factor of roughly twofold. However, our experiments did not support the authors’ claim that Whisper JAX is significantly faster on GPU platforms. In fact, the PyTorch implementation performed better when transcribing long audio files using a GPU.
Additionally, we found no significant difference in speed between the first and second transcriptions, contrary to a claim made by the Whisper JAX authors. Both implementations exhibited a similar pattern in this regard.
Keep in touch: LinkedIn