# occlum/example/client/benchmark.py
#
# 122 lines
# 4.6 KiB
# Python

import grpc
import tensorflow as tf
import argparse, time, grpc, asyncio
import numpy as np
from PIL import Image
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
class benchmark_engine(object):
    """Concurrent gRPC load generator for a TensorFlow Serving ``resnet`` model.

    One ``PredictRequest`` is built per concurrent task from a single input
    image. :meth:`run` opens one secure channel per task, sends the request
    ``loop_num`` times on each, and prints per-request and summary
    latency / throughput figures.
    """

    def __init__(self, server, image, certificate, concurrent_num=64, response_time=10):
        """Store the configuration and pre-build the requests.

        Args:
            server: ``host:port`` target passed to ``grpc.aio.secure_channel``.
            image: Path of the image file used as model input.
            certificate: Path of the PEM root-certificate file, or ``None`` to
                let gRPC use its default root certificates.
            concurrent_num: Number of concurrent tasks (one channel each).
            response_time: Stored on the instance; not read by this class.
        """
        self.server = server
        self.response_time = response_time
        self.concurrent_num = concurrent_num
        self.image = image
        self.certificate = certificate
        self.request_signatures = []
        # Kept for compatibility with existing users; nothing in this class
        # populates these two containers.
        self.request_stubs = []
        self.request_response_list = {}
        self.__prepare__()

    def __prepare__(self):
        """Append one ``PredictRequest`` per concurrent task to ``request_signatures``.

        The image is converted to RGB, resized to 224x224, scaled to float32
        values in [0, 1] and sent as a ``[1, 224, 224, 3]`` tensor under the
        ``input_1`` key of the ``resnet`` / ``serving_default`` signature.
        """
        # Every task sends the same input, so decode and serialise it once
        # instead of once per task. The context manager releases the file
        # handle; convert() returns an in-memory copy that outlives it.
        with Image.open(self.image) as source:
            img = source.convert('RGB')
        img = img.resize((224, 224), Image.BICUBIC)
        img_array = np.array(img).astype(np.float32) / 255.0
        tensor_proto = tf.make_tensor_proto(img_array, shape=[1, 224, 224, 3])

        for _ in range(self.concurrent_num):
            request = predict_pb2.PredictRequest()
            request.model_spec.name = 'resnet'
            request.model_spec.signature_name = 'serving_default'
            request.inputs['input_1'].CopyFrom(tensor_proto)
            self.request_signatures.append(request)
        return None

    def _load_root_certificates(self):
        """Return the PEM root certificates as bytes, or ``None`` if no path is set."""
        if self.certificate is None:
            return None
        # Close the file deterministically instead of leaking the handle.
        with open(self.certificate, 'rb') as cert_file:
            return cert_file.read()

    async def __connection__(self, task_idx, loop_num):
        """Send the task's request ``loop_num`` times over its own secure channel.

        Args:
            task_idx: Index into ``request_signatures`` selecting the request.
            loop_num: Number of sequential ``Predict`` calls to issue.

        Returns:
            A list of ``[response, latency_seconds]`` pairs, one per call.
        """
        request = self.request_signatures[task_idx]
        response_list = []
        creds = grpc.ssl_channel_credentials(
            root_certificates=self._load_root_certificates())
        format_string = 'query: {} channel, task {}, loop_idx {}, latency(ms) {:.1f}, tps: {:.1f}'
        async with grpc.aio.secure_channel(self.server, creds) as channel:
            stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
            for loop_idx in range(loop_num):
                # perf_counter is monotonic, so system clock adjustments
                # cannot skew the measured interval.
                start_time = time.perf_counter()
                response = await stub.Predict(request)
                latency = time.perf_counter() - start_time
                tps = 1 / latency if latency > 0 else float('inf')
                response_list.append([response, latency])
                print(format_string.format('secure', task_idx, loop_idx, 1000 * latency, tps))
        return response_list

    async def _run_connections(self, loop_num):
        """Run every connection task concurrently and return their results in task order."""
        tasks = [
            asyncio.ensure_future(self.__connection__(idx, loop_num))
            for idx in range(self.concurrent_num)
        ]
        # asyncio.wait() rejects an empty collection, so skip it when there
        # is nothing to wait for.
        if tasks:
            await asyncio.wait(tasks)
        # result() re-raises the exception of a failed task, if any.
        return [task.result() for task in tasks]

    def run(self, loop_num):
        """Execute one benchmark pass and print a summary line.

        A fresh event loop is created, installed as the current loop and
        closed again for each call.

        Args:
            loop_num: Number of requests each concurrent task sends.
        """
        start_time = time.perf_counter()
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.request_response_list.clear()
        try:
            response_list = loop.run_until_complete(self._run_connections(loop_num))
        finally:
            # Close the loop even when a task failed.
            loop.close()
        e2e_time = time.perf_counter() - start_time

        request_num = self.concurrent_num * loop_num
        if request_num <= 0:
            # No request was sent, so there is nothing to average.
            return
        request_time = sum(
            latency for task_results in response_list for _, latency in task_results)
        latency = request_time / request_num
        tps = request_num / e2e_time if e2e_time > 0 else float('inf')
        format_string = 'summary: cnum {}, e2e time(s) {}, average latency(ms) {}, tps: {}'
        print(format_string.format(self.concurrent_num, e2e_time, 1000 * latency, tps))
def main():
    """Validate the loop count, then run a warm-up pass and the measured pass.

    Reads the module-level ``args`` namespace produced by the argument parser
    in the ``__main__`` section of this file.
    """
    # Reject non-positive loop counts before building the engine, which loads
    # and preprocesses the image in its constructor.
    if args.loop <= 0:
        print("loop parameter needs to be bigger than 0")
        return
    benchmark_app = benchmark_engine(args.server, args.image, args.crt, args.cnum)
    # warm up
    benchmark_app.run(5)
    # start loop
    benchmark_app.run(args.loop)
if __name__ == '__main__':
    # Parse the command line into the module-level ``args`` that main() reads.
    parser = argparse.ArgumentParser()
    parser.add_argument('--server', default='localhost:9000',
                        help='TensorFlow Model Server Address')
    parser.add_argument('--crt', default=None, type=str, help='TLS certificate file path')
    parser.add_argument('--image', default='Siberian_Husky_bi-eyed_Flickr.jpg',
                        help='Path to the image')
    parser.add_argument('--cnum', default=8, type=int, help='Concurrent connection num')
    parser.add_argument('--loop', default=100, type=int, help='Requests loop num, should > 0')
    args = parser.parse_args()
    main()