昨天我们实现了Tensorflow.js的花卉识别程序,它的优点是不需要服务器支持,在客户端就可以完成花卉识别,使用非常方便,但也存在一些缺点。对于很多深度学习的应用来说,由于其训练模型复杂、计算量大,所以,一般来说,仍然需要服务器支持。下面仍然以花卉识别为例,介绍如何部署Tensorflow Serving及客户端编程。
TensorFlow Serving 是由 Google 开发和维护的开源项目,是 TensorFlow 生态系统的一部分,专门用于高效地部署和服务机器学习模型,具有高性能、灵活性、易于集成、可扩展性、易于管理和健壮性等多方面的优点。最重要的是它与 TensorFlow 紧密集成,实现了与 TensorFlow 生态系统无缝集成,支持 TensorFlow 模型的完整生命周期管理,从训练到部署再到监控。并且能够直接加载和使用 TensorFlow 的 SavedModel 格式,无需额外的转换步骤。
相对于许多通用的 web 服务器和 API 服务器(如 Flask、Django、FastAPI 等),但 TensorFlow Serving 专门针对机器学习模型的服务进行了优化,包括高效的内存管理、请求批处理、多线程处理等,能够在高并发和高负载的场景下表现出色。
这里不介绍Tensorflow Serving的安装,只介绍与编程有关的部署等问题。
文末附完整源代码链接。
一、服务端部署训练模型
1. 配置模型
按如下目录存放训练SavedModel模型:
/path/to/your/model/ └── your_model/ └── 1/ ├── saved_model.pb └── variables/ ├── variables.data-00000-of-00001 └── variables.index
2. 启动 TensorFlow Serving
执行以下命令:
tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=your_model --model_base_path=/path/to/your/model/your_model
这条命令用于启动 TensorFlow Serving 服务器,加载指定的模型,并配置其服务端口和 API 端口。以下是每个参数的详细解释:
3. 命令和参数解释
tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=your_model --model_base_path=/path/to/your/model/your_model
tensorflow_model_server
: 这是启动 TensorFlow Serving 服务器的命令。--port=8500
: 指定 gRPC API 的端口号。gRPC 是一种高性能的远程过程调用(RPC)框架,适用于需要高吞吐量和低延迟的应用场景。--rest_api_port=8501
: 指定 RESTful API 的端口号。RESTful API 基于 HTTP 协议,使用起来简单且广泛应用,方便客户端通过 HTTP 请求与 TensorFlow Serving 进行交互。--model_name=your_model
: 指定模型的名称。在服务中使用这个名称来引用和请求这个模型。这个名称可以在客户端请求中用来标识和调用特定的模型。--model_base_path=/path/to/your/model/your_model
: 指定模型所在的目录路径。TensorFlow Serving 会在这个目录中查找并加载模型。该路径应包含模型的文件和子目录。
二、客户端程序
1. 使用gRPC协议访问服务器
下面的代码实现了gRPC客户端 与 TensorFlow Serving 服务器交互。客户端对图片进行预处理后,向服务器发送请求,服务器完成花卉识别后,向客户端返回结果。以下是对关键代码的解释:
(1)图像预处理函数
def process_image(image: np.ndarray) -> np.ndarray: image_tensor = tf.convert_to_tensor(image) image_resized = tf.image.resize(image_tensor, (224, 224)) image_resized /= 255 return image_resized.numpy()
- 将输入的图像数组转换为 TensorFlow 张量。
- 调整图像大小为
(224, 224)
。 - 将图像归一化到
[0, 1]
范围。 - 返回预处理后的图像数组。
(2)加载和预处理图像的函数
def load_image(image_path): im = Image.open(image_path) image_arr = np.asarray(im) processed_image = process_image(image_arr) processed_image = np.expand_dims(processed_image, 0) return processed_image
- 加载图像文件并转换为 NumPy 数组。
- 调用
process_image
进行预处理。 - 将图像扩展为
(1, 224, 224, 3)
形状,适应批处理输入。 - 返回预处理后的图像。
(3)加载标签映射的函数
def load_label_map(label_map_path): with open(label_map_path, 'r', encoding='utf-8') as f: label_map = json.load(f) return label_map
- 从 JSON 文件中加载标签映射。
- 返回标签映射的字典。
(4)创建 gRPC 频道和存根
channel = grpc.insecure_channel('your server:8500') stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
- 创建一个 gRPC 频道,连接到 TensorFlow Serving 服务。
- 创建一个存根,用于与 TensorFlow Serving 进行通信。
(5)创建预测请求
request = predict_pb2.PredictRequest() request.model_spec.name = 'ai_flower' request.model_spec.signature_name = 'serving_default'
- 创建一个
PredictRequest
对象。 - 设置模型名称
ai_flower
和签名名称serving_default
。
(6)读取和预处理图像
image_path = 'test_images/image_00250.jpg' input_image = load_image(image_path)
- 设置图像路径。
- 调用
load_image
函数读取和预处理图像。
(7)设置请求输入张量
request.inputs['keras_layer_input'].CopyFrom( tf.make_tensor_proto(input_image, shape=input_image.shape))
- 将预处理后的图像设置为请求的输入张量。
(8)发送请求并获取响应
response = stub.Predict(request)
- 发送预测请求并获取响应。
(9)提取预测结果
output_tensor_name = 'dense' # 修改为实际的键名 if output_tensor_name in response.outputs: predictions = tf.make_ndarray(response.outputs[output_tensor_name]) else: print(f"Output tensor '{output_tensor_name}' not found in the response.") predictions = []
- 假设输出张量的键名是
dense
,从响应中提取预测结果。 - 如果键名不同,请根据实际情况进行修改。
2. 使用REST API协议访问服务器
与上述使用gRPC协议访问服务器实现的功能一样。以下只对有区别代码的进行解释:
(1)服务器 URL
server_url = 'http://your_server:8501/v1/models/ai_flower:predict'
- 指定 TensorFlow Serving 服务器的 URL,发送预测请求到
ai_flower
模型的predict
端点。
(2)发送 POST 请求到服务器
response = requests.post(server_url, json=data)
- 通过
POST
请求将图像数据发送到 TensorFlow Serving 服务器。
(3)检查响应状态
if response.status_code == 200: result = response.json() predictions = np.array(result['predictions']) label_map = load_label_map('label_map.json') top_k = 5 top_indices = np.argsort(predictions[0])[-top_k:][::-1] for i in top_indices: label_id = i + 1 label_name = label_map.get(str(label_id), 'Unknown') confidence = predictions[0][i] print(f"label_id: {label_id}, Label: {label_name}, Confidence: {confidence:.4f}") else: print(f"Request failed with status code {response.status_code}") print("Response:", response.text)
- 检查响应状态码是否为
200
(即请求成功)。 - 解析响应 JSON 数据,提取预测结果。
- 加载标签映射文件。
- 获取前 5 位预测结果,打印每个预测类别的标签和信心分数。
- 如果请求失败,打印状态码和响应内容。