Deploying multiple pretrained models (tar.gz files) on a single SageMaker endpoint

data-mining tensorflow machine-learning-model aws sagemaker pretrained
2022-03-09 04:23:39

We followed these steps:

  1. Trained 5 TensorFlow models on a local machine, each with a different training set.
  2. Saved them in .h5 format.
  3. Converted them to tar.gz (Model1.tar.gz, ... Model5.tar.gz) and uploaded them to an S3 bucket.
  4. Successfully deployed a single model to an endpoint with the following code:
from sagemaker.tensorflow import TensorFlowModel
sagemaker_model = TensorFlowModel(model_data=tarS3Path + 'model{}.tar.gz'.format(1),
                                  role=role, framework_version='1.13',
                                  sagemaker_session=sagemaker_session)
predictor = sagemaker_model.deploy(initial_instance_count=1,
                                   instance_type='ml.m4.xlarge')
predictor.predict(data.values[:, 0:])

The output was: {'predictions': [[153.55], [79.8196], [45.2843]]}
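For reference, the predictor serializes the input array into the request body that TensorFlow Serving's REST API expects, a JSON object with an "instances" key. A minimal sketch of that payload shape (the row values below are made up for illustration; they stand in for data.values[:, 0:]):

```python
import json

# Illustrative rows, standing in for data.values[:, 0:]
rows = [[5.1, 3.5], [6.2, 2.9]]

# TensorFlow Serving's REST API expects {"instances": [...]} as the body
payload = json.dumps({"instances": rows})
print(payload)  # {"instances": [[5.1, 3.5], [6.2, 2.9]]}
```

This is the same shape used in the invoke_endpoint payload in the answer below.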

The problem now is that we cannot issue 5 separate deploy statements and create 5 different endpoints for the 5 models. To avoid that, we tried two approaches:

i) Using SageMaker's MultiDataModel

from sagemaker.multidatamodel import MultiDataModel
sagemaker_model1 = MultiDataModel(name="laneMultiModels", model_data_prefix=tarS3Path,
                                  model=sagemaker_model,  # the same sagemaker_model trained above
                                  # role=role, framework_version='1.13',
                                  sagemaker_session=sagemaker_session)
predictor = sagemaker_model1.deploy(initial_instance_count=1,
                                    instance_type='ml.m4.xlarge')
predictor.predict(data.values[:, 0:], target_model='model{}.tar.gz'.format(1))

Here we hit the following error at the deploy stage: An error occurred (ValidationException) when calling the CreateModel operation: Your Ecr Image 763104351884.dkr.ecr.us-east-2.amazonaws.com/tensorflow-inference:1.13-cpu does not contain required com.amazonaws.sagemaker.capabilities.multi-models=true Docker label. (The TensorFlow 1.13 inference image predates multi-model endpoint support; the answer below uses a 2.x image, which does carry this label.)

ii) Creating the endpoint manually

import boto3
import botocore
import sagemaker
sm_client = boto3.client('sagemaker')
image = sagemaker.image_uris.retrieve('knn','us-east-2')
container = {
    "Image": image,
    "ModelDataUrl": tarS3Path,
    "Mode": "MultiModel"
}
# Note: if "knn" is replaced with "tensorflow", an error occurs at this stage itself
response = sm_client.create_model(
              ModelName        = 'multiple-tar-models',
              ExecutionRoleArn = role,
              Containers       = [container])
response = sm_client.create_endpoint_config(
    EndpointConfigName = 'multiple-tar-models-endpointconfig',
    ProductionVariants=[{
        'InstanceType':        'ml.t2.medium',
        'InitialInstanceCount': 1,
        'InitialVariantWeight': 1,
        'ModelName':            'multiple-tar-models',
        'VariantName':          'AllTraffic'}])
response = sm_client.create_endpoint(
              EndpointName       = 'tarmodels-endpoint',
              EndpointConfigName = 'multiple-tar-models-endpointconfig')

With this approach, too, we could not create the endpoint.

2 Answers

I had also been looking for an answer to this; after a few days of experimenting with a friend, we managed to get it working. Here are some code snippets we used, which you can adapt to your use case.

import json

import boto3

image = '763104351884.dkr.ecr.us-east-1.amazonaws.com/tensorflow-inference:2.2.0-cpu'
container = { 
    'Image': image,
    'ModelDataUrl': model_data_location,
    'Mode': 'MultiModel'
}

sagemaker_client = boto3.client('sagemaker')

# Create Model
response = sagemaker_client.create_model(
              ModelName = model_name,
              ExecutionRoleArn = role,
              Containers = [container])

# Create Endpoint Configuration
response = sagemaker_client.create_endpoint_config(
    EndpointConfigName = endpoint_configuration_name,
    ProductionVariants=[{
        'InstanceType': 'ml.t2.medium',
        'InitialInstanceCount': 1,
        'InitialVariantWeight': 1,
        'ModelName': model_name,
        'VariantName': 'AllTraffic'}])

# Create Endpoint
response = sagemaker_client.create_endpoint(
              EndpointName = endpoint_name,
              EndpointConfigName = endpoint_configuration_name)

# Invoke Endpoint
sagemaker_runtime_client = boto3.client('sagemaker-runtime')

content_type = "application/json" # The MIME type of the input data in the request body.
accept = "application/json" # The desired MIME type of the inference in the response.
payload = json.dumps({"instances": [1.0, 2.0, 5.0]}) # Payload for inference.
target_model = 'model1.tar.gz'


response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name, 
    ContentType=content_type,
    Accept=accept,
    Body=payload,
    TargetModel=target_model,
)

response
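The Body field of the invoke_endpoint response is a streaming object that has to be read and JSON-decoded before you can use the predictions. A minimal sketch; the fake_response below simulates the streaming body with io.BytesIO purely for illustration (a real call returns a botocore StreamingBody, which exposes the same read() interface):

```python
import io
import json

def parse_prediction(response):
    """Read and JSON-decode the streaming Body of an invoke_endpoint response."""
    return json.loads(response["Body"].read().decode("utf-8"))

# Simulated response purely for illustration
fake_response = {"Body": io.BytesIO(b'{"predictions": [[153.55]]}')}
result = parse_prediction(fake_response)
print(result)  # {'predictions': [[153.55]]}
```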

Also, make sure your model tar.gz files have this structure:

└── model1.tar.gz
     └── <version number>
         ├── saved_model.pb
         └── variables
            └── ...
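Assuming each model has already been exported in SavedModel format under a numeric version directory (as the tree above requires), it can be packaged with Python's standard tarfile module so that the version directory sits at the archive root. The export/ directory and file names here are hypothetical, created only to illustrate the layout:

```python
import os
import tarfile

def package_savedmodel(export_dir, tar_path):
    """Package a SavedModel export (e.g. export/1/saved_model.pb) into a
    tar.gz whose root entries are the version directories."""
    with tarfile.open(tar_path, "w:gz") as tar:
        for version in os.listdir(export_dir):
            tar.add(os.path.join(export_dir, version), arcname=version)

# Dummy layout for illustration (a real export comes from saving the model
# in SavedModel format, not from an empty placeholder file)
os.makedirs("export/1/variables", exist_ok=True)
open("export/1/saved_model.pb", "wb").close()

package_savedmodel("export", "model1.tar.gz")
with tarfile.open("model1.tar.gz") as tar:
    print(sorted(tar.getnames()))  # ['1', '1/saved_model.pb', '1/variables']
```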

More information on this.