Spaces:
Runtime error
Runtime error
fix pipeline referencing
Browse files
main.py
CHANGED
|
@@ -19,7 +19,7 @@ from typing import Dict, List, Union
|
|
| 19 |
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
|
| 20 |
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
|
| 21 |
from optimum.pipelines import pipeline as ort_pipeline
|
| 22 |
-
from transformers import BertTokenizer, BertForSequenceClassification,
|
| 23 |
|
| 24 |
from utils import calculate_inference_time
|
| 25 |
|
|
@@ -105,7 +105,7 @@ def load_pipeline(pipeline_name: str) -> None:
|
|
| 105 |
"""
|
| 106 |
if pipeline_name == "pt_pipeline":
|
| 107 |
model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
|
| 108 |
-
pipeline =
|
| 109 |
elif pipeline_name == "ort_pipeline":
|
| 110 |
model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
|
| 111 |
if not ONNX_MODEL_PATH.exists():
|
|
@@ -120,7 +120,7 @@ def load_pipeline(pipeline_name: str) -> None:
|
|
| 120 |
model = ORTModelForSequenceClassification.from_pretrained(
|
| 121 |
OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
|
| 122 |
)
|
| 123 |
-
pipeline =
|
| 124 |
elif pipeline_name == "ort_quantized_pipeline":
|
| 125 |
if not QUANTIZED_MODEL_PATH.exists():
|
| 126 |
quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
|
|
@@ -130,7 +130,7 @@ def load_pipeline(pipeline_name: str) -> None:
|
|
| 130 |
model = ORTModelForSequenceClassification.from_pretrained(
|
| 131 |
QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
|
| 132 |
)
|
| 133 |
-
pipeline =
|
| 134 |
print(type(pipeline))
|
| 135 |
return pipeline
|
| 136 |
|
|
|
|
| 19 |
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
|
| 20 |
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
|
| 21 |
from optimum.pipelines import pipeline as ort_pipeline
|
| 22 |
+
from transformers import BertTokenizer, BertForSequenceClassification, pipeline as pt_pipeline
|
| 23 |
|
| 24 |
from utils import calculate_inference_time
|
| 25 |
|
|
|
|
| 105 |
"""
|
| 106 |
if pipeline_name == "pt_pipeline":
|
| 107 |
model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
|
| 108 |
+
pipeline = pt_pipeline("sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=model)
|
| 109 |
elif pipeline_name == "ort_pipeline":
|
| 110 |
model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
|
| 111 |
if not ONNX_MODEL_PATH.exists():
|
|
|
|
| 120 |
model = ORTModelForSequenceClassification.from_pretrained(
|
| 121 |
OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
|
| 122 |
)
|
| 123 |
+
pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
|
| 124 |
elif pipeline_name == "ort_quantized_pipeline":
|
| 125 |
if not QUANTIZED_MODEL_PATH.exists():
|
| 126 |
quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
|
|
|
|
| 130 |
model = ORTModelForSequenceClassification.from_pretrained(
|
| 131 |
QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
|
| 132 |
)
|
| 133 |
+
pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
|
| 134 |
print(type(pipeline))
|
| 135 |
return pipeline
|
| 136 |
|