122 lines
4.6 KiB
Python
122 lines
4.6 KiB
Python
import grpc
|
|
import tensorflow as tf
|
|
import argparse, time, grpc, asyncio
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from tensorflow_serving.apis import predict_pb2
|
|
from tensorflow_serving.apis import prediction_service_pb2_grpc
|
|
|
|
|
|
class benchmark_engine(object):
    """Concurrent gRPC benchmark client for a TensorFlow Serving model.

    On construction, one ``PredictRequest`` per concurrent task is built from
    a single image. :meth:`run` then opens one TLS channel per task, sends
    that task's request ``loop_num`` times, and prints per-request and
    summary latency / throughput figures.
    """

    def __init__(self, server, image, certificate, concurrent_num=64, response_time=10):
        """Store the benchmark settings and pre-build the requests.

        Args:
            server: Target address handed to ``grpc.aio.secure_channel``.
            image: Path of the image file opened with PIL.
            certificate: Path of a PEM root certificate file, or ``None`` to
                let gRPC use its default root certificates.
            concurrent_num: Number of concurrent connections (tasks).
            response_time: Stored on the instance; not read by any method
                of this class.
        """
        self.server = server
        self.response_time = response_time
        self.concurrent_num = concurrent_num
        self.image = image
        self.certificate = certificate
        self.request_signatures = []
        self.request_stubs = []
        self.request_response_list = {}
        self.__prepare__()

    def __prepare__(self):
        """Append one ``PredictRequest`` per task to ``request_signatures``.

        The image is decoded, converted to RGB, resized to 224x224, scaled
        to ``float32`` values in [0, 1] and serialised exactly once; every
        request receives its own copy of the resulting tensor proto.
        """
        if self.concurrent_num <= 0:
            # Nothing to build, so do not touch the image file at all.
            return None

        # The context manager releases the file handle as soon as the pixel
        # data has been converted into a new in-memory image.
        with Image.open(self.image) as raw_image:
            img = raw_image.convert('RGB').resize((224, 224), Image.BICUBIC)
        img_array = np.array(img).astype(np.float32) / 255.0
        tensor_proto = tf.make_tensor_proto(img_array, shape=[1, 224, 224, 3])

        for _ in range(self.concurrent_num):
            request = predict_pb2.PredictRequest()
            request.model_spec.name = 'resnet'
            request.model_spec.signature_name = 'serving_default'
            request.inputs['input_1'].CopyFrom(tensor_proto)
            self.request_signatures.append(request)
        return None

    async def __connection__(self, task_idx, loop_num):
        """Send one task's request ``loop_num`` times over its own channel.

        Args:
            task_idx: Index into ``request_signatures``; also printed as the
                task id.
            loop_num: Number of sequential ``Predict`` calls to issue.

        Returns:
            A list with one ``[response, latency_seconds]`` entry per call.
        """
        request = self.request_signatures[task_idx]
        response_list = []

        # With no certificate path, pass None so that gRPC falls back to its
        # default root certificates instead of failing on open(None).
        root_certificates = None
        if self.certificate is not None:
            with open(self.certificate, 'rb') as cert_file:
                root_certificates = cert_file.read()
        creds = grpc.ssl_channel_credentials(root_certificates=root_certificates)

        format_string = 'query: {} channel, task {}, loop_idx {}, latency(ms) {:.1f}, tps: {:.1f}'
        async with grpc.aio.secure_channel(self.server, creds) as channel:
            stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
            for loop_idx in range(loop_num):
                # perf_counter is monotonic and high resolution, unlike the
                # wall clock, so it is the right tool for short intervals.
                start_time = time.perf_counter()
                response = await stub.Predict(request)
                latency = time.perf_counter() - start_time
                tps = 1 / latency if latency > 0 else float('inf')
                response_list.append([response, latency])
                print(format_string.format('secure', task_idx, loop_idx, 1000*latency, tps))
        return response_list

    def run(self, loop_num):
        """Run all tasks concurrently and print a summary line.

        A fresh event loop is created for every call and closed afterwards,
        even when waiting is interrupted. If any task raised, its exception
        is re-raised here when the results are collected.

        Args:
            loop_num: Number of requests each task sends.

        Returns:
            None. The summary is printed only if at least one request
            completed.
        """
        start_time = time.perf_counter()

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.request_response_list.clear()
        try:
            connections = [
                loop.create_task(self.__connection__(idx, loop_num))
                for idx in range(self.concurrent_num)
            ]
            # asyncio.wait() rejects an empty collection, so skip it when
            # there is nothing to wait for.
            if connections:
                loop.run_until_complete(asyncio.wait(connections))
        finally:
            loop.close()

        stop_time = time.perf_counter()

        response_list = [task.result() for task in connections]

        # Count what was actually measured so the averages stay meaningful
        # for zero or negative loop / task counts.
        request_num = sum(len(task_responses) for task_responses in response_list)
        if request_num == 0:
            return None

        request_time = sum(
            latency
            for task_responses in response_list
            for _, latency in task_responses
        )
        e2e_time = stop_time - start_time
        latency = request_time / request_num
        tps = request_num / e2e_time if e2e_time > 0 else float('inf')
        format_string = 'summary: cnum {}, e2e time(s) {}, average latency(ms) {}, tps: {}'
        print(format_string.format(self.concurrent_num, e2e_time, 1000*latency, tps))
        return None
|
|
|
|
def main():
    """Run a warm-up pass followed by the measured benchmark pass.

    Reads the module-level ``args`` namespace (``server``, ``image``,
    ``crt``, ``cnum``, ``loop``). Prints a message and returns without
    doing any work when ``args.loop`` is not a positive number.
    """
    # Validate first: constructing the engine already opens the image and
    # builds every request, which is wasted work for an invalid loop count.
    if args.loop <= 0:
        print("loop parameter needs to be bigger than 0")
        return

    benchmark_app = benchmark_engine(args.server, args.image, args.crt, args.cnum)

    # warm up
    benchmark_app.run(5)
    # start loop
    benchmark_app.run(args.loop)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point. `args` is deliberately bound at module
    # level because main() reads it as a global.
    parser = argparse.ArgumentParser()
    parser.add_argument('--server', default='localhost:9000',
                        help='TensorFlow Model Server Address')
    parser.add_argument('--crt', default=None, type=str, help='TLS certificate file path')
    parser.add_argument('--image', default='Siberian_Husky_bi-eyed_Flickr.jpg',
                        help='Path to the image')
    parser.add_argument('--cnum', default=8, type=int, help='Concurrent connection num')
    parser.add_argument('--loop', default=100, type=int, help='Requests loop num, should > 0')
    args = parser.parse_args()

    main()