Deploying OpenAI's Whisper Large V3 Model on SageMaker Using Hugging Face Libraries

Published at 1/23/2024
Author: Mohamad Albaker Kawtharani
Categories: whisper, huggingface, sagemaker, aws

In a recent project, I was using OpenAI's Whisper model for transcription. The sprint goal was to deploy it on SageMaker, relying on the Hugging Face libraries to keep things smooth. However, I ran into a blocker: a ModelError that puzzled me for a couple of hours.
Here is the full error:

ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "{
"code": 400,
"type": "InternalServerException",
"message": "Wrong index found for \u003c|0.02|\u003e: should be None but found 50366."

After some research, I found a solution in Issue #58 of the community discussions in the OpenAI Whisper Large V3 repository on Hugging Face. The issue is caused by a mismatch between transformers library versions, and the fix is to force a more recent version. Note that, at the time of writing, that version is not available in the Hugging Face containers supported by the SageMaker SDK, so we have to install it ourselves.

In this blog post, I will show a straightforward way to apply this fix, whether you deploy Whisper Large models from a SageMaker domain or from a SageMaker notebook.

1. Setting Up the Directory and Files

In this step, we create the directory structure and files for the deployment: a whisper-model directory with a code/ subfolder containing inference.py and requirements.txt.
The inference.py script defines model_fn, which loads the Whisper model and processor and returns an automatic-speech-recognition pipeline. The requirements.txt file pins the newer transformers (and accelerate) versions needed to avoid the error above.

import os

# Directory and file paths
dir_path = './whisper-model'
inference_file_path = os.path.join(dir_path, 'code/inference.py')
requirements_file_path = os.path.join(dir_path, 'code/requirements.txt')

# Create the directory structure
os.makedirs(os.path.dirname(inference_file_path), exist_ok=True)

# Inference.py content
inference_content = '''
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Model and task specifications
model_id = "openai/whisper-large-v3"
task = "automatic-speech-recognition"

# Device configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

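def model_fn(model_dir):
    # model_dir is unused: the weights are downloaded from the Hugging Face Hub,
    # so the endpoint needs internet access when it starts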
    try:
        print(f"Loading model: {model_id}")
        # Load the model
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        model.to(device)
        print(f"Model loaded on device: {device}")

        # Load the processor
        processor = AutoProcessor.from_pretrained(model_id)
        print("Processor loaded")

        # Create and return a pipeline for ASR
        asr_pipeline = pipeline(
            task,
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            return_timestamps=True,
            torch_dtype=torch_dtype,
            device=device,
        )
        print("Pipeline created")

        return asr_pipeline
    except Exception as e:
        print(f"An error occurred: {e}")
        raise
'''

# Write the inference.py file
with open(inference_file_path, 'w') as file:
    file.write(inference_content)

# Requirements.txt content
requirements_content = '''
transformers==4.38.0
accelerate==0.26.1
'''

# Write the requirements.txt file
with open(requirements_file_path, 'w') as file:
    file.write(requirements_content)
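The script above overrides only model_fn and leaves request handling to the toolkit's default behavior. If you prefer to make that step explicit, the Hugging Face inference toolkit also picks up a predict_fn(data, model) defined in inference.py. Below is a minimal sketch you could append to inference_content; it assumes the decoded request is either a dict with an "inputs" key (the raw audio bytes) plus optional "parameters", or the raw input itself. The fake_pipeline stand-in is hypothetical and only there to exercise the function.

```python
from typing import Any, Callable


def predict_fn(data: Any, asr_pipeline: Callable[..., Any]) -> Any:
    """Run the ASR pipeline returned by model_fn on a decoded request.

    Assumption: `data` is either {"inputs": <audio bytes>, "parameters": {...}}
    or the raw audio input itself.
    """
    if isinstance(data, dict) and "inputs" in data:
        inputs = data["inputs"]
        parameters = data.get("parameters") or {}
    else:
        inputs, parameters = data, {}
    # Forward any pipeline options sent by the client (e.g. chunk_length_s)
    return asr_pipeline(inputs, **parameters)


# Hypothetical stand-in for the real pipeline created in model_fn
def fake_pipeline(audio, **kwargs):
    return {"text": f"{len(audio)} bytes", "kwargs": kwargs}


print(predict_fn({"inputs": b"\x00\x01", "parameters": {"chunk_length_s": 30}}, fake_pipeline))
```

Keeping the parameter pass-through lets a client tweak pipeline options per request without redeploying.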

2. Archiving the Directory

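In this step, we archive the contents of the whisper-model directory into whisper-model.tar.gz with shutil.make_archive.
SageMaker expects model artifacts as a gzipped tarball, and passing './whisper-model' as the root directory puts the code/ folder at the root of the archive, where the inference toolkit looks for it.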

import shutil
shutil.make_archive('./whisper-model', 'gztar', './whisper-model')
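Before uploading, it can be worth checking that code/inference.py and code/requirements.txt sit at the root of the archive rather than under an extra whisper-model/ folder. The sketch below rebuilds the same layout in a temporary directory (the placeholder file contents are made up) using the same make_archive call, then lists the members with tarfile. To check your real archive, call list_archive('whisper-model.tar.gz').

```python
import os
import posixpath
import shutil
import tarfile
import tempfile


def list_archive(archive_path: str) -> list[str]:
    """Return the normalized member names of a .tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        names = {posixpath.normpath(m.name) for m in tar.getmembers()}
    names.discard(".")  # make_archive also stores the root directory as "."
    return sorted(names)


def build_demo_archive(workdir: str) -> str:
    """Recreate the whisper-model/code layout and archive it like step 2."""
    model_dir = os.path.join(workdir, "whisper-model")
    os.makedirs(os.path.join(model_dir, "code"), exist_ok=True)
    for name in ("inference.py", "requirements.txt"):
        with open(os.path.join(model_dir, "code", name), "w") as f:
            f.write("# placeholder\n")
    # Same call pattern as above: archive the *contents* of whisper-model
    return shutil.make_archive(model_dir, "gztar", model_dir)


with tempfile.TemporaryDirectory() as tmp:
    members = list_archive(build_demo_archive(tmp))
    print(members)  # code/ should appear at the archive root
```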

3. Uploading the Model to S3

This step uploads the compressed archive to an Amazon S3 bucket (the session's default bucket unless you change it).
We use the SageMaker session's upload_data helper, which returns the S3 URI of the uploaded archive.

import sagemaker

# Get the SageMaker session and default S3 bucket
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket() # Change if you want to store in a different bucket
prefix = 'whisper/code'

# Upload the model to S3
s3_path = sagemaker_session.upload_data(
    'whisper-model.tar.gz', 
    bucket=bucket,
    key_prefix=prefix
)

print(f"Model uploaded to {s3_path}")
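The model_data argument in the next step needs exactly the URI that upload_data returned. If your notebook kernel restarts, you can rebuild it instead of uploading again. The helper below is a hypothetical convenience, not part of the SageMaker SDK; it assumes that, for a single file, upload_data stores it at s3://<bucket>/<key_prefix>/<file name>, which is how the code above uses it. The bucket name in the example is made up.

```python
import posixpath


def model_s3_uri(bucket: str, key_prefix: str, local_file: str) -> str:
    """Rebuild the URI upload_data is expected to return for a single file."""
    key = posixpath.join(key_prefix.strip("/"), posixpath.basename(local_file))
    return f"s3://{bucket}/{key}"


# Hypothetical bucket name, for illustration only
print(model_s3_uri("sagemaker-us-east-1-123456789012", "whisper/code", "whisper-model.tar.gz"))
```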

4. Deploying the Model on SageMaker

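Here, we deploy the model with the HuggingFaceModel class, specifying the container's Transformers, PyTorch, and Python versions, the S3 location of our archive, the execution role, and the instance type.
Note that transformers_version='4.26.0' only selects the base container image; the transformers==4.38.0 pinned in requirements.txt is installed on top of it when the endpoint starts, and that newer version is what resolves the error.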

from sagemaker.huggingface import HuggingFaceModel
import sagemaker

role = sagemaker.get_execution_role()

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.26.0',
    pytorch_version='1.13.1',
    py_version='py39',
    model_data=s3_path,
    role=role,
)
# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge"
)
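Since the fix depends on the version pinned in requirements.txt being newer than the one baked into the container image, a quick pre-deployment sanity check is to compare the two. The sketch below parses simple name==version lines and compares purely numeric versions; it is a simplification (it ignores extras, environment markers, and pre-release tags) rather than a replacement for the packaging library's version handling. requirements_content repeats the string from step 1.

```python
def parse_pins(requirements: str) -> dict[str, str]:
    """Map package name -> pinned version for simple 'name==version' lines."""
    pins = {}
    for line in requirements.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and whitespace
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip().lower()] = version.strip()
    return pins


def version_tuple(version: str) -> tuple[int, ...]:
    """Turn '4.38.0' into (4, 38, 0); assumes purely numeric components."""
    return tuple(int(part) for part in version.split("."))


requirements_content = '''
transformers==4.38.0
accelerate==0.26.1
'''
container_transformers = "4.26.0"  # value passed as transformers_version above

pinned = parse_pins(requirements_content)["transformers"]
print(pinned, version_tuple(pinned) > version_tuple(container_transformers))
```

Comparing tuples of integers avoids the string-comparison trap where "4.9.0" would sort after "4.26.0".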

5. Making Predictions with the Deployed Model

In this final step, we set a DataSerializer on the predictor so that raw audio bytes are sent with an audio content type.
We then use the deployed endpoint to transcribe a local audio file.

from sagemaker.serializers import DataSerializer

predictor.serializer = DataSerializer(content_type='audio/x-audio')

# Make sure the input file "sample.wav" exists
with open("sample.wav", "rb") as f:
    data = f.read()

# The response is deserialized from JSON into a dict
result = predictor.predict(data)
print(result["text"])
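Because the pipeline is created with return_timestamps=True, the response should contain the full "text" plus a list of "chunks", each with a [start, end] "timestamp" pair in seconds. The sketch below turns such a response into readable lines. The sample response is made up, and the function treats a missing end time as unknown, since the pipeline can report None when the audio is cut off mid-word.

```python
def format_chunks(result: dict) -> list[str]:
    """Render Whisper pipeline chunks as '[start -> end] text' lines."""
    lines = []
    for chunk in result.get("chunks", []):
        start, end = chunk["timestamp"]
        end_str = f"{end:.2f}s" if end is not None else "?"
        lines.append(f"[{start:.2f}s -> {end_str}] {chunk['text'].strip()}")
    return lines


# Made-up response in the shape returned with return_timestamps=True
sample_result = {
    "text": " Hello there. Thanks for listening.",
    "chunks": [
        {"timestamp": [0.0, 1.8], "text": " Hello there."},
        {"timestamp": [1.8, None], "text": " Thanks for listening."},
    ],
}
for line in format_chunks(sample_result):
    print(line)
```

Against the live endpoint, you would pass it the dictionary returned by predictor.predict(data).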

Hope it was helpful!!
