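We followed these steps:
- Trained 5 TensorFlow models on a local machine, using 5 different training sets.
- Saved them in .h5 format.
- Packaged them as tar.gz archives (Model1.tar.gz, ... Model5.tar.gz) and uploaded them to an S3 bucket.
- Successfully deployed a single model to an endpoint using the following code: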
from sagemaker.tensorflow import TensorFlowModel

sagemaker_model = TensorFlowModel(model_data=tarS3Path + 'model{}.tar.gz'.format(1),
                                  role=role, framework_version='1.13',
                                  sagemaker_session=sagemaker_session)
predictor = sagemaker_model.deploy(initial_instance_count=1,
                                   instance_type='ml.m4.xlarge')
predictor.predict(data.values[:, 0:])
The output was: {'predictions': [[153.55], [79.8196], [45.2843]]}
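For reference, the packaging step listed above can be sketched roughly as follows. This is a minimal sketch using only the standard-library `tarfile` module. The function name `package_model_dir` is hypothetical, and the directory layout that the serving container expects inside the archive is not shown in this post, so it is left as an assumption.

```python
import tarfile
from pathlib import Path


def package_model_dir(model_dir, archive_path):
    """Pack the files under model_dir into a gzip-compressed tar archive.

    Files are stored relative to model_dir, so the archive has no
    leading parent directory. Returns the list of member names added.
    NOTE: the internal layout required by the serving container is an
    assumption here; adjust model_dir accordingly.
    """
    model_dir = Path(model_dir)
    names = []
    with tarfile.open(archive_path, "w:gz") as tar:
        for path in sorted(model_dir.rglob("*")):
            if path.is_file():
                arcname = path.relative_to(model_dir).as_posix()
                tar.add(path, arcname=arcname)
                names.append(arcname)
    return names
```

The resulting archive could then be uploaded to the bucket, for example with the boto3 S3 client's `upload_file(Filename, Bucket, Key)`.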
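The problem now is that we cannot use 5 separate deploy statements and create 5 separate endpoints for the 5 models. To get around this, we tried two approaches:
i) Using SageMaker's MultiDataModel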
from sagemaker.multidatamodel import MultiDataModel

sagemaker_model1 = MultiDataModel(name="laneMultiModels", model_data_prefix=tarS3Path,
                                  model=sagemaker_model,  # the same sagemaker_model defined above
                                  # role=role, framework_version='1.13',
                                  sagemaker_session=sagemaker_session)
predictor = sagemaker_model1.deploy(initial_instance_count=1,
                                    instance_type='ml.m4.xlarge')
predictor.predict(data.values[:, 0:], target_model='model{}.tar.gz'.format(1))
Here we get the following error at the deploy stage: An error occurred (ValidationException) when calling the CreateModel operation: Your Ecr Image 763104351884.dkr.ecr.us-east-2.amazonaws.com/tensorflow-inference:1.13-cpu does not contain required com.amazonaws.sagemaker.capabilities.multi-models=true Docker label(s).
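For context, the `target_model` argument in the `predict` call above corresponds to the `TargetModel` parameter of the SageMaker runtime `InvokeEndpoint` API, which selects one archive under the model data prefix for each request. Below is a minimal sketch of that lower-level call, assuming a multi-model endpoint already exists. The helper name `invoke_target_model` and the JSON payload shape are hypothetical, and the client is passed in as an argument so the snippet makes no AWS call until it is actually used.

```python
import json


def invoke_target_model(runtime_client, endpoint_name, target_model, instances):
    """Send one JSON request to a multi-model endpoint for a given archive.

    runtime_client is expected to be boto3.client("sagemaker-runtime").
    target_model is the archive path relative to the model data prefix,
    e.g. "model1.tar.gz". The {"instances": ...} payload is an assumption
    based on the TensorFlow Serving REST request format.
    """
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        TargetModel=target_model,
        Body=json.dumps({"instances": instances}),
    )
    # The response body is a stream; read it and decode the JSON payload.
    return json.loads(response["Body"].read())
```

This only helps once the endpoint itself can be created, which is the step that fails here.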
ii) Creating the endpoint manually
import boto3
import botocore
import sagemaker

sm_client = boto3.client('sagemaker')
image = sagemaker.image_uris.retrieve('knn', 'us-east-2')
container = {
    "Image": image,
    "ModelDataUrl": tarS3Path,
    "Mode": "MultiModel"
}
# Note: if I replace "knn" with "tensorflow", it already fails at this stage
response = sm_client.create_model(
    ModelName='multiple-tar-models',
    ExecutionRoleArn=role,
    Containers=[container])
response = sm_client.create_endpoint_config(
    EndpointConfigName='multiple-tar-models-endpointconfig',
    ProductionVariants=[{
        'InstanceType': 'ml.t2.medium',
        'InitialInstanceCount': 1,
        'InitialVariantWeight': 1,
        'ModelName': 'multiple-tar-models',
        'VariantName': 'AllTraffic'}])
response = sm_client.create_endpoint(
    EndpointName='tarmodels-endpoint',
    EndpointConfigName='multiple-tar-models-endpointconfig')
With this approach, too, the endpoint could not be created.
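To make the intent of the container definition in approach ii) explicit: in `MultiModel` mode, `ModelDataUrl` refers to an S3 prefix under which the individual archives live, not to a single archive. A minimal sketch of a helper that builds such a definition follows. The name `multi_model_container` is hypothetical, and the trailing-slash normalization is a convention assumed here rather than a documented requirement. The image passed in would still need to carry the multi-model Docker label named in the error under approach i).

```python
def multi_model_container(image_uri, s3_prefix):
    """Build a container definition dict for create_model in MultiModel mode.

    s3_prefix should point at the S3 "folder" holding the .tar.gz archives.
    Appending a trailing slash is an assumed convention, not a verified
    requirement of the API.
    """
    if not s3_prefix.startswith("s3://"):
        raise ValueError("s3_prefix must start with 's3://'")
    if s3_prefix.endswith(".tar.gz"):
        raise ValueError("expected an S3 prefix, not a single archive")
    if not s3_prefix.endswith("/"):
        s3_prefix += "/"
    return {
        "Image": image_uri,
        "ModelDataUrl": s3_prefix,
        "Mode": "MultiModel",
    }
```

With the variables from the code above, this would be called as `multi_model_container(image, tarS3Path)` and the result passed in the `Containers` list of `create_model`.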