Edit

kc3-lang/brotli/research/find_opt_references.cc

Branch :

  • Show log

    Commit

  • Author : Evgenii Kliuchnikov
    Date : 2025-01-07 13:59:06
    Hash : 57610b71
    Message : translate includes in brotli/research PiperOrigin-RevId: 713033033

  • research/find_opt_references.cc
  • /* Copyright 2016 Google Inc. All Rights Reserved.
       Author: zip753@gmail.com (Ivan Nikulin)
    
       Distributed under MIT license.
       See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
    */
    
    /* Tool for generating optimal backward references for the input file. Uses
       sais-lite library for building suffix array. */
    
    #include <algorithm>
    #include <cassert>
    #include <cstdio>
    #include <cstring>
    #include <functional>
    #include <utility>
    #include <vector>
    
    #include <gflags/gflags.h>
    using gflags::ParseCommandLineFlags;
    
    #include "third_party/absl/flags/flag.h"
    #include <sais.hxx>
    
    DEFINE_bool(advanced, false, "Advanced searching mode: finds all longest "
        "matches at positions that are not covered by matches of length at least "
        "max_length. WARNING: uses much more memory than simple mode, especially "
        "for small values of min_length.");
    DEFINE_int32(min_length, 1, "Minimal length of found backward references.");
    /* For advanced mode. */
    DEFINE_int32(long_length, 32,
                 "Maximal length of found backward references for advanced mode.");
    DEFINE_int32(skip, 1, "Number of bytes to skip.");
    
    const size_t kFileBufferSize = (1 << 16);  // 64KB
    
    typedef int sarray_type;  // Can't make it unsigned because of templates :(
    typedef uint8_t input_type;
    typedef uint32_t lcp_type;
    typedef std::pair<int, std::vector<int> > entry_type;
    typedef std::function<void(sarray_type*, lcp_type*, size_t, int, int, int, int,
                               int)> Fn;
    
    void ReadInput(FILE* fin, input_type* storage, size_t input_size) {
      size_t last_pos = 0;
      size_t available_in;
      fseek(fin, 0, SEEK_SET);
      do {
        available_in = fread(storage + last_pos, 1, kFileBufferSize, fin);
        last_pos += available_in;
      } while (available_in != 0);
      assert(last_pos == input_size);
    }
    
    void BuildLCP(input_type* storage, sarray_type* sarray, lcp_type* lcp,
                  size_t size, uint32_t* pos) {
      for (int i = 0; i < size; ++i) {
        pos[sarray[i]] = i;
      }
      uint32_t k = 0;
      lcp[size - 1] = 0;
      for (int i = 0; i < size; ++i) {
        if (pos[i] == size - 1) {
          k = 0;
          continue;
        }
        uint32_t j = sarray[pos[i] + 1];  // Suffix which follow i-th suffix in SA.
        while (i + k < size && j + k < size && storage[i + k] == storage[j + k]) {
          ++k;
        }
        lcp[pos[i]] = k;
        if (k > 0) --k;
      }
    }
    
    inline void PrintReference(sarray_type* sarray, lcp_type* lcp, size_t size,
                               int idx, int left_ix, int right_ix, int left_lcp,
                               int right_lcp, FILE* fout) {
      int max_lcp_ix;
      if (right_ix == size - 1 || (left_ix >= 0 && left_lcp >= right_lcp)) {
        max_lcp_ix = left_ix;
      } else {
        max_lcp_ix = right_ix;
      }
      int dist = idx - sarray[max_lcp_ix];
      assert(dist > 0);
      fputc(1, fout);
      fwrite(&idx, sizeof(int), 1, fout);   // Position in input.
      fwrite(&dist, sizeof(int), 1, fout);  // Backward distance.
    }
    
    inline void GoLeft(sarray_type* sarray, lcp_type* lcp, int idx, int left_ix,
                       int left_lcp, entry_type* entry) {
      entry->first = left_lcp;
      if (left_lcp > absl::GetFlag(FLAGS_long_length)) return;
      for (; left_ix >= 0; --left_ix) {
        if (lcp[left_ix] < left_lcp) break;
        if (sarray[left_ix] < idx) {
          entry->second.push_back(idx - sarray[left_ix]);
        }
      }
    }
    
    inline void GoRight(sarray_type* sarray, lcp_type* lcp, int idx, size_t size,
                        int right_ix, int right_lcp, entry_type* entry) {
      entry->first = right_lcp;
      if (right_lcp > absl::GetFlag(FLAGS_long_length)) return;
      for (; right_ix < size - 1; ++right_ix) {
        if (lcp[right_ix] < right_lcp) break;
        if (sarray[right_ix] < idx) {
          entry->second.push_back(idx - sarray[right_ix]);
        }
      }
    }
    
    inline void StoreReference(sarray_type* sarray, lcp_type* lcp, size_t size,
                               int idx, int left_ix, int right_ix, int left_lcp,
                               int right_lcp, entry_type* entries) {
      if (right_ix == size - 1 || (left_ix >= 0 && left_lcp > right_lcp)) {
        // right is invalid or left is better
        GoLeft(sarray, lcp, idx, left_ix, left_lcp, &entries[idx]);
      } else if (left_ix < 0 || (right_ix < size - 1 && right_lcp > left_lcp)) {
        // left is invalid or right is better
        GoRight(sarray, lcp, idx, size, right_ix, right_lcp, &entries[idx]);
      } else {  // both are valid and of equal length
        GoLeft(sarray, lcp, idx, left_ix, left_lcp, &entries[idx]);
        GoRight(sarray, lcp, idx, size, right_ix, right_lcp, &entries[idx]);
      }
    }
    
    void ProcessReferences(sarray_type* sarray, lcp_type* lcp, size_t size,
                           uint32_t* pos, const Fn& process_output) {
      int min_length = absl::GetFlag(FLAGS_min_length);
      for (int idx = absl::GetFlag(FLAGS_skip); idx < size; ++idx) {
        int left_lcp = -1;
        int left_ix;
        for (left_ix = pos[idx] - 1; left_ix >= 0; --left_ix) {
          if (left_lcp == -1 || left_lcp > lcp[left_ix]) {
            left_lcp = lcp[left_ix];
          }
          if (left_lcp == 0) break;
          if (sarray[left_ix] < idx) break;
        }
    
        int right_lcp = -1;
        int right_ix;
        for (right_ix = pos[idx]; right_ix < size - 1; ++right_ix) {
          if (right_lcp == -1 || right_lcp > lcp[right_ix]) {
            right_lcp = lcp[right_ix];
          }
          // Stop if we have better result from the left side already.
          if (right_lcp < left_lcp && left_ix >= 0) break;
          if (right_lcp == 0) break;
          if (sarray[right_ix] < idx) break;
        }
    
        if ((left_ix >= 0 && left_lcp >= min_length) ||
            (right_ix < size - 1 && right_lcp >= min_length)) {
          process_output(sarray, lcp, size, idx, left_ix, right_ix, left_lcp,
                         right_lcp);
        }
      }
    }
    
    void ProcessEntries(entry_type* entries, size_t size, FILE* fout) {
      int long_length = absl::GetFlag(FLAGS_long_length);
      std::vector<std::pair<int, int> > segments;
      size_t idx;
      for (idx = 0; idx < size;) {
        entry_type& entry = entries[idx];
        if (entry.first > long_length) {
          // Add segment.
          if (segments.empty() || segments.back().second < idx) {
            segments.push_back({idx, idx + entry.first});
          } else {
            segments.back().second = idx + entry.first;
          }
        }
        ++idx;
      }
      printf("Segments generated.\n");
      size_t segments_ix = 0;
      for (idx = 0; idx < size;) {
        if (idx == segments[segments_ix].first) {
          // Skip segment.
          idx = segments[segments_ix].second;
        } else {
          for (auto& dist : entries[idx].second) {
            fputc(1, fout);
            fwrite(&idx, sizeof(int), 1, fout);   // Position in input.
            fwrite(&dist, sizeof(int), 1, fout);  // Backward distance.
          }
          ++idx;
        }
      }
    }
    
    int main(int argc, char* argv[]) {
      base::ParseCommandLine(&argc, &argv);
      if (argc != 3) {
        printf("usage: %s input_file output_file\n", argv[0]);
        return 1;
      }
    
      FILE* fin = fopen(argv[1], "rb");
      FILE* fout = fopen(argv[2], "w");
    
      fseek(fin, 0, SEEK_END);
      int input_size = ftell(fin);
      fseek(fin, 0, SEEK_SET);
      printf("The file size is %u bytes\n", input_size);
    
      input_type* storage = new input_type[input_size];
    
      ReadInput(fin, storage, input_size);
      fclose(fin);
    
      sarray_type* sarray = new sarray_type[input_size];
      saisxx(storage, sarray, input_size);
      printf("Suffix array calculated.\n");
    
      // Inverse suffix array.
      uint32_t* pos = new uint32_t[input_size];
    
      lcp_type* lcp = new lcp_type[input_size];
      BuildLCP(storage, sarray, lcp, input_size, pos);
      printf("LCP array constructed.\n");
      delete[] storage;
    
      using std::placeholders::_1;
      using std::placeholders::_2;
      using std::placeholders::_3;
      using std::placeholders::_4;
      using std::placeholders::_5;
      using std::placeholders::_6;
      using std::placeholders::_7;
      using std::placeholders::_8;
      entry_type* entries;
      if (absl::GetFlag(FLAGS_advanced)) {
        entries = new entry_type[input_size];
        for (size_t i = 0; i < input_size; ++i) entries[i].first = -1;
      }
      Fn print = std::bind(PrintReference, _1, _2, _3, _4, _5, _6, _7, _8, fout);
      Fn store = std::bind(StoreReference, _1, _2, _3, _4, _5, _6, _7, _8, entries);
    
      ProcessReferences(sarray, lcp, input_size, pos,
                        absl::GetFlag(FLAGS_advanced) ? store : print);
      printf("References processed.\n");
    
      if (absl::GetFlag(FLAGS_advanced)) {
        int good_cnt = 0;
        uint64_t avg_cnt = 0;
        for (size_t i = 0; i < input_size; ++i) {
          if (entries[i].first != -1) {
            ++good_cnt;
            avg_cnt += entries[i].second.size();
          }
        }
        printf("Number of covered positions = %d\n", good_cnt);
        printf("Average number of references per covered position = %.4lf\n",
                static_cast<double>(avg_cnt) / good_cnt);
        ProcessEntries(entries, input_size, fout);
        printf("Entries processed.\n");
      }
    
      fclose(fout);
      return 0;
    }