import json
import sys
from typing import Optional

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, udf
from pyspark.sql.types import FloatType, MapType, StringType

from fasttext_infer import fasttext_infer
from preprocess_content import fasttext_preprocess_func
|
def get_fasttext_pred(content: str) -> Optional[str]:
    """Run fastText inference on *content* and keep only positive predictions.

    Args:
        content (str): raw text from the input column; may be None or empty
            for null/malformed parquet cells.

    Returns:
        Optional[str]: JSON string with ``pred_label`` and ``pred_score``
        when the model predicts ``__label__pos``; otherwise None so the
        row can be dropped by a downstream ``isNotNull`` filter.
    """
    # Spark hands the UDF None for null cells; the preprocess/infer helpers
    # are not None-safe, so drop such rows up front.
    if not content:
        return None

    norm_content = fasttext_preprocess_func(content)
    label, score = fasttext_infer(norm_content)

    # Keep only positive predictions. ensure_ascii=False preserves
    # non-ASCII text (e.g. CJK) in the serialized payload.
    if label == '__label__pos':
        return json.dumps({'pred_label': label, 'pred_score': score}, ensure_ascii=False)
    return None
|
|
| if __name__ == "__main__": |
|
|
| input_path = sys.argv[1] |
| save_path = sys.argv[2] |
|
|
| content_key = "content" |
|
|
| spark = (SparkSession.builder.enableHiveSupport() |
| .config("hive.exec.dynamic.partition", "true") |
| .config("hive.exec.dynamic.partition.mode", "nonstrict") |
| .appName("FastTextInference") |
| .getOrCreate()) |
|
|
| predict_udf = udf(get_fasttext_pred) |
|
|
| |
| df = spark.read.parquet(input_path) |
| df = df.withColumn("fasttext_pred", predict_udf(col(content_key))) |
| df = df.filter(col("fasttext_pred").isNotNull()) |
| |
| |
| df.coalesce(1000).write.mode("overwrite").parquet(save_path) |
|
|
| spark.stop() |
|
|