# google-research — 92 lines, 2.8 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Binary to compute table equality metrics."""

from collections.abc import Mapping, Sequence
import csv
import json
import os
import zipfile

from absl import app
from absl import flags
import tensorflow as tf

from deplot import metrics


# Prefix (not a directory): main reads <path>_targets.zip and
# <path>_predictions.zip and writes <path>.jsonl and <path>-metrics.json.
_PATH = flags.DEFINE_string(
    'path', None,
    'Path prefix of the input archives: reads <path>_targets.zip and '
    '<path>_predictions.zip, writes <path>.jsonl and <path>-metrics.json.')

# A single JSONL file; each line must be an object with "target" and
# "prediction" keys.
_JSONL = flags.DEFINE_string(
    'jsonl', None,
    'JSONL file with one {"target": ..., "prediction": ...} object per line.')
38def _to_markdown(bts):39reader = csv.reader(bts.decode().splitlines(), delimiter=',')40parts = ['title |'] + [' | '.join(row) for row in reader]41return '\n'.join(parts)42
43
def _get_files(suffix):
  """Reads every CSV member of the archive `<--path>_<suffix>.zip`.

  Args:
    suffix: Archive name suffix, e.g. 'targets' or 'predictions'.

  Returns:
    A dict mapping each CSV member's base name to its raw bytes.
  """
  zip_path = f'{_PATH.value}_{suffix}.zip'
  # zipfile.ZipFile does not close a file object it was handed, so the GFile
  # gets its own `with` to avoid leaking the handle.
  with tf.io.gfile.GFile(zip_path, 'rb') as fp, zipfile.ZipFile(fp) as archive:
    # NOTE(review): members are keyed by base name only, so CSVs with the
    # same name in different archive directories overwrite each other.
    return {
        os.path.basename(name): archive.read(name)
        for name in archive.namelist()
        if name.endswith('.csv')
    }
def _load_from_zips():
  """Pairs CSV targets/predictions from the --path zips and dumps a JSONL copy.

  Reads `<--path>_targets.zip` and `<--path>_predictions.zip`, converts each
  CSV to text with `_to_markdown`, and writes the pairs to `<--path>.jsonl`
  so they can be re-scored later with --jsonl.

  Returns:
    (targets, predictions), ordered by file name. Each target is a
    one-element list; each prediction is a string.

  Raises:
    ValueError: if some target file has no prediction with the same name.
  """
  targets_by_id = _get_files('targets')
  predictions_by_id = _get_files('predictions')

  # Check before opening the output file so a bad input does not leave a
  # truncated .jsonl behind. Extra predictions without targets are ignored.
  missing = sorted(targets_by_id.keys() - predictions_by_id.keys())
  if missing:
    raise ValueError(f'No prediction found for targets: {missing}')

  targets, predictions = [], []
  with tf.io.gfile.GFile(_PATH.value + '.jsonl', 'w') as f:
    for k in sorted(targets_by_id):
      target = _to_markdown(targets_by_id[k])
      prediction = _to_markdown(predictions_by_id[k])
      # Wrapped in a list; presumably the metrics accept several references
      # per example — confirm in deplot.metrics.
      targets.append([target])
      predictions.append(prediction)
      line = {'input': {'id': k}, 'target': target, 'prediction': prediction}
      f.write(json.dumps(line) + '\n')
  return targets, predictions


def _load_from_jsonl():
  """Reads (targets, predictions) from the --jsonl file, one object per line.

  String targets are wrapped in a one-element list so both input modes give
  the metrics the same shape. This is required to re-score the .jsonl written
  by --path mode, which stores `target` as a plain string.
  """
  targets, predictions = [], []
  with tf.io.gfile.GFile(_JSONL.value) as f:
    for line in f:
      if not line.strip():
        continue  # Tolerate blank lines (e.g. a trailing empty line).
      example = json.loads(line)
      target = example['target']
      if isinstance(target, str):
        target = [target]
      targets.append(target)
      predictions.append(example['prediction'])
  return targets, predictions


def main(argv: Sequence[str]) -> None:
  """Loads targets/predictions from --path or --jsonl and prints metrics.

  The metrics dict is printed as JSON. In --path mode it is also written to
  `<--path>-metrics.json`.

  Raises:
    app.UsageError: if extra positional command-line arguments are given.
    ValueError: if both or neither of --path/--jsonl are set, or (--path
      mode) a target has no matching prediction.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if _PATH.value and _JSONL.value:
    raise ValueError('Only one of --path or --jsonl can be specified.')

  if _PATH.value:
    targets, predictions = _load_from_zips()
  elif _JSONL.value:
    targets, predictions = _load_from_jsonl()
  else:
    raise ValueError('One of --path or --jsonl must be specified.')

  metric = {}
  metric.update(metrics.table_datapoints_precision_recall(targets, predictions))
  metric.update(metrics.table_number_accuracy(targets, predictions))
  metric_log = json.dumps(metric, indent=2)
  print(metric_log)
  if _PATH.value:
    with tf.io.gfile.GFile(_PATH.value + '-metrics.json', 'w') as f:
      f.write(metric_log)
if __name__ == '__main__':
  # absl's app.run parses the flags, then calls main with the leftover argv.
  app.run(main)