Source code for pydrobert.torch._parsing

# Copyright 2022 Sean Robertson

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import warnings
import math

from typing import (
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    TextIO,
    Tuple,
    Type,
    TypeVar,
    Union,
    Union,
)
from logging import Logger
from collections import OrderedDict

import torch
import numpy as np

import pydrobert.torch.config as config

from ._textgrid import TextGrid, TEXTTIER

K = TypeVar("K", bound=Union[str, int, np.signedinteger])
F = TypeVar("F", bound=Union[float, np.floating])


[docs] def parse_arpa_lm( file_: Union[TextIO, str], token2id: Optional[Dict[str, np.signedinteger]] = None, to_base_e: bool = None, ftype: Type[F] = float, logger: Optional[Logger] = None, ) -> List[Dict[Union[K, Tuple[K, ...]], F]]: r"""Parse an ARPA statistical language model An `ARPA language model <https://cmusphinx.github.io/wiki/arpaformat/>`__ is an n-gram model with back-off probabilities. It is formatted as:: \data\ ngram 1=<count> ngram 2=<count> ... ngram <N>=<count> \1-grams: <logp> <token[t]> <logb> <logp> <token[t]> <logb> ... \2-grams: <logp> <token[t-1]> <token[t]> <logb> ... \<N>-grams: <logp> <token[t-<N>+1]> ... <token[t]> ... \end\ Parameters ---------- file_ Either the path or a file pointer to the file. token2id A dictionary whose keys are token strings and values are ids. If set, tokens will be replaced with ids on read to_base_e ARPA files store log-probabilities and log-backoffs in base-10. This ftype The floating-point type to store log-probabilities and backoffs as logger If specified, progress will be written to this logger at INFO level Returns ------- prob_dicts : list A list of the same length as there are orders of n-grams in the file (e.g. if the file contains up to tri-gram probabilities then `prob_dicts` will be of length 3). Each element is a dictionary whose key is the word sequence (earliest word first). For 1-grams, this is just the word. For n > 1, this is a tuple of words. Values are either a tuple of ``logp, logb`` of the log-probability and backoff log-probability, or, in the case of the highest-order n-grams that don't need a backoff, just the log probability. Warnings -------- Version ``0.3.0`` and prior do not have the option `to_base_e`, always returning values in log base 10. While this remains the default, it is deprecated and will be removed in a later version. This function is not safe for JIT scripting or tracing. """ if isinstance(file_, str): with open(file_) as f: return parse_arpa_lm(f, token2id, to_base_e, ftype, logger) if to_base_e is None: warnings.warn( "The default of to_base_e will be changed to True in a later version. " "Please manually specify this argument to suppress this warning" ) to_base_e = False norm = math.log10(math.e) if to_base_e else 1.0 norm = ftype(norm) if logger is None: print_ = lambda x: None else: print_ = logger.info line = "" print_("finding \\data\\ header") for line in file_: if line.strip() == "\\data\\": break if line.strip() != "\\data\\": raise IOError("Could not find \\data\\ line. Is this an ARPA file?") ngram_counts: List[Dict[int, int]] = [] count_pattern = re.compile(r"^ngram\s+(\d+)\s*=\s*(\d+)$") print_("finding n-gram counts") for line in file_: line = line.strip() if not line: continue match = count_pattern.match(line) if match is None: break n, count = (int(x) for x in match.groups()) print_(f"there are {count} {n}-grams") if len(ngram_counts) < n: ngram_counts.extend(0 for _ in range(n - len(ngram_counts))) ngram_counts[n - 1] = count prob_dicts: List[Dict[Union[K, Tuple[K, ...]], F]] = [dict() for _ in ngram_counts] ngram_header_pattern = re.compile(r"^\\(\d+)-grams:$") ngram_entry_pattern = re.compile(r"^(-?\d+(?:\.\d+)?(?:[Ee]-?\d+)?)\s+(.*)$") while line != "\\end\\": match = ngram_header_pattern.match(line) if match is None: raise IOError('line "{}" is not valid'.format(line)) ngram = int(match.group(1)) if ngram > len(ngram_counts): raise IOError( "{}-grams count was not listed, but found entry" "".format(ngram) ) dict_ = prob_dicts[ngram - 1] for line in file_: line = line.strip() if not line: continue match = ngram_entry_pattern.match(line) if match is None: break logp, rest = match.groups() tokens = tuple(rest.strip().split()) # IRSTLM and SRILM allow for implicit backoffs on non-final # n-grams, but final n-grams must not have backoffs logb = ftype(0.0) if len(tokens) == ngram + 1 and ngram < len(prob_dicts): try: logb = ftype(tokens[-1]) tokens = tokens[:-1] except ValueError: pass if len(tokens) != ngram: raise IOError( 'expected line "{}" to be a(n) {}-gram' "".format(line, ngram) ) if token2id is not None: tokens = tuple(token2id[tok] for tok in tokens) if ngram == 1: tokens = tokens[0] if ngram != len(ngram_counts): dict_[tokens] = (ftype(logp) / norm, logb / norm) else: dict_[tokens] = ftype(logp) / norm if line != "\\end\\": raise IOError("Could not find \\end\\ line") for ngram_m1, (ngram_count, dict_) in enumerate(zip(ngram_counts, prob_dicts)): if len(dict_) != ngram_count: raise IOError(f"Expected {ngram_count} {ngram_m1}-grams, got {len(dict_)}") return prob_dicts
class _AltTree(object): def __init__(self, parent=None): self.parent = parent self.tokens = [] if parent is not None: parent.tokens.append([self.tokens]) def new_branch(self): assert self.parent self.tokens = [] self.parent.tokens[-1].append(self.tokens) def _trn_line_to_transcript(x: Tuple[str, bool]) -> Optional[Tuple[str, List[str]]]: line, warn = x line = line.strip() if not line: return None try: last_open = line.rindex("(") last_close = line.rindex(")") if last_open > last_close: raise ValueError() except ValueError: raise IOError("Line does not end in utterance id") utt_id = line[last_open + 1 : last_close] line = line[:last_open].strip() transcript = [] token = "" alt_tree = _AltTree() found_alt = False while len(line): c = line[0] line = line[1:] if c == "{": found_alt = True if token: if alt_tree.parent is None: transcript.append(token) else: alt_tree.tokens.append(token) token = "" alt_tree = _AltTree(alt_tree) elif c == "/" and alt_tree.parent is not None: if token: alt_tree.tokens.append(token) token = "" alt_tree.new_branch() elif c == "}" and alt_tree.parent is not None: if token: alt_tree.tokens.append(token) token = "" if not alt_tree.tokens: raise IOError('Empty alternate found ("{ }")') alt_tree = alt_tree.parent if alt_tree.parent is None: assert len(alt_tree.tokens) == 1 transcript.append((alt_tree.tokens[0], -1, -1)) alt_tree.tokens = [] elif c == " ": if token: if alt_tree.parent is None: transcript.append(token) else: alt_tree.tokens.append(token) token = "" else: token += c if token and alt_tree.parent is None: transcript.append(token) if found_alt and warn: warnings.warn( 'Found an alternate in transcription for utt="{}". ' "Transcript will contain an array of alternates at that " "point, and will not be compatible with transcript_to_token " "until resolved. To suppress this warning, set warn=False" "".format(utt_id) ) return utt_id, transcript
[docs] def read_trn_iter( trn: Union[TextIO, str], warn: bool = True, processes: int = 0, chunk_size: int = config.DEFT_CHUNK_SIZE, ) -> Tuple[str, List[str]]: """Read a NIST sclite transcript file, yielding individual transcripts Identical to :func:`read_trn`, but yields individual transcript entries rather than a full list. Ideal for large transcript files. Parameters ---------- trn warn processes chunk_size Yields ------ utt_id : str transcript : list of str """ # implementation note: there's a lot of weirdness here. I'm trying to # match sclite's behaviour. A few things # - the last parentheses are always the utterance. Everything else is # the word # - An unmatched '}' is treated as a word # - A '/' not within curly braces is a word # - If the utterance ends without closing its alternate, the alternate is # discarded # - Comments from other formats are not comments here... # - ...but everything passed the last pair of parentheses is ignored... # - ...and internal parentheses are treated as words # - Spaces are treated as part of the utterance id # - Seg faults on empty alternates if isinstance(trn, str): with open(trn) as trn: yield from read_trn_iter(trn, warn, processes) elif processes == 0: for line in trn: x = _trn_line_to_transcript((line, warn)) if x is not None: yield x else: with torch.multiprocessing.Pool(processes) as pool: transcripts = pool.imap( _trn_line_to_transcript, ((line, warn) for line in trn), chunk_size ) for x in transcripts: if x is not None: yield x pool.close() pool.join()
[docs] def read_trn( trn: Union[TextIO, str], warn: bool = True, processes: int = 0, chunk_size: int = config.DEFT_CHUNK_SIZE, ) -> List[Tuple[str, List[str]]]: """Read a NIST sclite transcript file into a list of transcripts `sclite <http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm>`__ is a commonly used scoring tool for ASR. This function converts a transcript input file ("trn" format) into a list of `transcripts`, where each element is a tuple of ``utt_id, transcript``. ``transcript`` is a list split by spaces. Parameters ---------- trn The transcript input file. Will open if `trn` is a path. warn The "trn" format uses curly braces and forward slashes to indicate transcript alterations. This is largely for scoring purposes, such as swapping between filled pauses, not for training. If `warn` is :obj:`True`, a warning will be issued via the ``warnings`` module every time an alteration appears in the "trn" file. Alterations appear in `transcripts` as elements of ``([[alt_1_word_1, alt_1_word_2, ...], [alt_2_word_1, alt_2_word_2, ...], ...], -1, -1)`` so that :func:`transcript_to_token` will not attempt to process alterations as token start and end times. processes The number of processes used to parse the lines of the trn file. If ``0``, will be performed on the main thread. Otherwise, the file will be read on the main thread and parsed using `processes` many processes. chunk_size The number of lines to be processed by a worker process at a time. Applicable when ``processes > 0`` Returns ------- transcripts : list A list of pairs ``utt_id, transcript`` where `utt_id` is a string identifying the utterance and `transcript` is a list of tokens in the utterance's transcript. Notes ----- Any null words (``@``) in the "trn" file are encoded verbatim. """ return list(read_trn_iter(trn, warn, processes, chunk_size))
[docs] def write_trn( transcripts: Iterable[Tuple[str, List[str]]], trn: Union[str, TextIO] ) -> None: """From an iterable of transcripts, write to a NIST "trn" file This is largely the inverse operation of :func:`read_trn`. In general, elements of a transcript (`transcripts` contains pairs of ``utt_id, transcript``) could be tokens or tuples of ``x, start, end`` (providing the start and end times of tokens, respectively). However, ``start`` and ``end`` are ignored when writing "trn" files. ``x`` could be the token or a list of alternates, as described in :func:`read_trn`. Parameters ---------- transcripts trn """ if isinstance(trn, str): with open(trn, "w") as trn: return write_trn(transcripts, trn) def _handle_x(x): if isinstance(x, str): return x + " " # x was a token # x is a list of alternates ret = [] for alts in x: elem = "" for xx in alts: elem += _handle_x(xx) ret.append(elem) ret = "{ " + "/ ".join(ret) + "} " return ret for utt_id, transcript in transcripts: line = "" for x in transcript: # first get rid of starts and ends, if possible. This is not # ambiguous with numerical alternates, since alternates should # always be strings and, more importantly, always have placeholder # start and end values try: if len(x) == 3 and np.isreal(x[1]) and np.isreal(x[2]): x = x[0] except TypeError: pass line += _handle_x(x) trn.write(line) trn.write("(") trn.write(utt_id) trn.write(")\n")
[docs] def read_ctm( ctm: Union[TextIO, str], wc2utt: Optional[dict] = None ) -> List[Tuple[str, List[Tuple[str, float, float]]]]: """Read a NIST sclite "ctm" file into a list of transcriptions `sclite <http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm>`__ is a commonly used scoring tool for ASR. This function converts a time-marked conversation file ("ctm" format) into a list of `transcripts`. Each element is a tuple of ``utt_id, transcript``, where ``transcript`` is itself a list of triples ``token, start, end``, ``token`` being a string, ``start`` being the start time of the token (in seconds), and ``end`` being the end time of the token (in seconds) Parameters ---------- ctm The time-marked conversation file pointer. Will open if `ctm` is a path wc2utt "ctm" files identify utterances by waveform file name and channel. If specified, `wc2utt` consists of keys ``wfn, chan`` (e.g. ``'940328', 'A'``) to unique utterance IDs. If `wc2utt` is unspecified, the waveform file names are treated as the utterance IDs, and the channel is ignored Returns ------- transcripts : list Each element is a tuple of ``utt_id, transcript``. `utt_id` is a string identifying the utterance. `transcript` is a list of triples ``token, start, end``, `token` being the token (a string), `start` being a float of the start time of the token (in seconds), and `end` being the end time of the token. Notes ----- "ctm", like "trn", has "support" for alternate transcriptions. It is, as of sclite version 2.10, very buggy. For example, it cannot handle multiple alternates in the same utterance. Plus, tools like `Kaldi <http://kaldi-asr.org/>`__ use the Unix command that the sclite documentation recommends to sort a ctm, ``sort +0 -1 +1 -2 +2nb -3``, which does not maintain proper ordering for alternate delimiters. Thus, :func:`read_ctm` will error if it comes across those delimiters """ if isinstance(ctm, str): with open(ctm, "r") as ctm: return read_ctm(ctm, wc2utt) transcripts = OrderedDict() for line_no, line in enumerate(ctm): line = line.split(";;")[0].strip() if not line: continue line = line.split() try: if len(line) not in {5, 6}: raise ValueError() wfn, chan, start, dur, token = line[:5] if wc2utt is None: utt_id = wfn else: utt_id = wc2utt[(wfn, chan)] start = float(start) end = start + float(dur) if start < 0.0 or start > end: raise ValueError() transcripts.setdefault(utt_id, []).append((token, start, end)) except ValueError: raise ValueError("Could not parse line {} of ctm".format(line_no + 1)) except KeyError: raise KeyError( "ctm line {}: ({}, {}) was not found in wc2utt".format( line_no, wfn, chan ) ) return [ (utt_id, sorted(transcript, key=lambda x: x[1])) for utt_id, transcript in list(transcripts.items()) ]
def write_ctm( transcripts: Sequence[Tuple[str, Sequence[Tuple[str, float, float]]]], ctm: Union[TextIO, str], utt2wc: Union[Mapping[str, Tuple[str, str]], str] = config.DEFT_CTM_CHANNEL, ) -> None: f"""From a list of transcripts, write to a NIST "ctm" file This is the inverse operation of :func:`read_ctm`. For each element of ``transcript`` within the ``utt_id, transcript`` pairs of elements in `transcripts`, ``token, start, end``, ``start`` and ``end`` must be non-negative Parameters ---------- transcripts ctm utt2wc "ctm" files identify utterances by waveform file name and channel. If specified as a dict, `utt2wc` consists of utterance IDs as keys, and wavefile name and channels as values ``wfn, chan`` (e.g. ``'940328', '{config.DEFT_CTM_CHANNEL}'``). If `utt2wc` is a string, each utterance IDs will be mapped to ``wfn`` and `utt2wc` as the channel. """ if isinstance(ctm, str): with open(ctm, "w") as ctm: return write_ctm(transcripts, ctm, utt2wc) is_dict = not isinstance(utt2wc, str) segments = [] for utt_id, transcript in transcripts: try: wfn, chan = utt2wc[utt_id] if is_dict else (utt_id, utt2wc) except KeyError: raise KeyError('Utt "{}" has no value in utt2wc'.format(utt_id)) for tup in transcript: if isinstance(tup, str) or len(tup) != 3 or tup[1] < 0.0 or tup[2] < 0.0: raise ValueError( 'Utt "{}" contains token "{}" with no timing info' "".format(utt_id, tup) ) token, start, end = tup duration = end - start if duration < 0.0: raise ValueError( 'Utt "{}" contains token with negative duration' "".format(utt_id) ) segments.append((wfn, chan, start, duration, token)) segments = sorted(segments) for segment in segments: ctm.write("{} {} {} {} {}\n".format(*segment))
[docs] def read_textgrid( tg: Union[TextIO, str], tier_id: Union[str, int] = config.DEFT_TEXTGRID_TIER_ID, fill_token: Optional[str] = None, ) -> Tuple[List[Tuple[str, float, float]], float, float]: """Read TextGrid file as a transcription TextGrid is the transcription format of `Praat <https://www.fon.hum.uva.nl/praat/>`_. Parameters ---------- tg The TextGrid file. Will open if `tg` is a path. tier_id Either the name of the tier (first occurence) or the index of the tier to extract. fill_token If set, any intervals missing from the tier will be filled with an interval of this token before being returned. Returns ------- transcript : list A list of triples of ``token, start, end``, `token` being the token (a string), `start` being a float of the start time of the token (in seconds), and `end` being the end time of the token. If the tier is a PointTier, the start and end times will be the same. start_time : float The start time of the tier (in seconds) end_time : float The end time of the tier (in seconds) Notes ----- This function does not check for whitespace in or around token labels. This may cause issues if writing as another file type, like :func:`write_trn`. Start and end times (including any filled intervals) are determined from the tier's values, not necessarily those of the top-level container. This is most likely a technicality, however: they should not differ normally. """ if isinstance(tg, str): with open(tg) as f: return read_textgrid(f, tier_id, fill_token) tg_ = TextGrid(tg.read()) if isinstance(tier_id, str): tier = None for tier_ in tg_.tiers: if tier_.nameid == tier_id: tier = tier_ break if tier is None: raise ValueError(f"Could not find tier '{tier_id}'") else: tier = tg_.tiers[tier_id] if tier.classid == TEXTTIER: transcript = [ (x[1], float(x[0]), float(x[0])) for x in sorted(tier.simple_transcript) ] else: transcript = [ (x[2], float(x[0]), float(x[1])) for x in sorted(tier.simple_transcript) ] i = 0 start_time = tier.xmin while i < len(transcript): _, next_start, end_time = transcript[i] if fill_token is not None and start_time < next_start: transcript.insert(i, (fill_token, start_time, next_start)) i += 1 i += 1 start_time = end_time if fill_token is not None and tier.xmax is not None and start_time < tier.xmax: transcript.append((fill_token, start_time, tier.xmax)) return transcript, tier.xmin, tier.xmax
[docs] def write_textgrid( transcript: Sequence[Tuple[str, float, float]], tg: Union[TextIO, str], start_time: Optional[float] = None, end_time: Optional[float] = None, tier_name: str = config.DEFT_TEXTGRID_TIER_NAME, point_tier: Optional[bool] = None, precision: int = config.DEFT_FLOAT_PRINT_PRECISION, ) -> None: """Write a transcription as a TextGrid file TextGrid is the transcription format of `Praat <https://www.fon.hum.uva.nl/praat/>`_. This function saves `transcript` as a tier within a TextGrid file. Parameters ---------- transcript The transcription to write. Contains triples ``tok, start, end``, where `tok` is the token, `start` is its start time, and `end` is its end time. `transcript` must be non-empty. tg The file to write. Will open if `tg` is a path. start_time The start time of the recording (in seconds). If not specified, it will be inferred from the minimum start time of the intervals in `transcript`. end_time The end time of the recording (in seconds). If not specified, it will be inferred from the maximum end time of the intervals in `transcript`. tier_name What name to save the tier with. point_tier Whether to save as a point tier (:obj:`True`) or an interval tier. If unset, the value is inferred to be a point tier if all segments are length 0 (within precision `precision`); an interval tier otherwise. precision The precision of floating-point values to save times with. """ if isinstance(tg, str): with open(tg, "w") as tg: return write_textgrid(transcript, tg, start_time, end_time, tier_name) transcript = list(transcript) if not len(transcript): raise ValueError(f"Will not write an empty transcript") tier_start_time = min(x[1] for x in transcript) tier_end_time = max(x[2] for x in transcript) if start_time is None: start_time = tier_start_time elif start_time > tier_start_time: raise ValueError( f"gave start_time={start_time} but an interval starts at " f"{tier_start_time}" ) if end_time is None: end_time = tier_end_time elif end_time < tier_end_time: raise ValueError( f"gave end_time={end_time} but an interval ends at {tier_end_time}" ) if point_tier is None: point_tier = all( f"{x[1]:0.{precision}f}" == f"{x[2]:0.{precision}f}" for x in transcript ) tier_type = "TextTier" if point_tier else "IntervalTier" # fmt: off tg.write( 'File type = "ooTextFile"\n' 'Object class = "TextGrid"\n' f"{start_time:0.{precision}f}\n" f"{end_time:0.{precision}f}\n" "<exists>\n" "1\n" f'"{tier_type}"\n' f'"{tier_name}"\n' f"{tier_start_time:0.{precision}f}\n" f"{tier_end_time:0.{precision}f}\n" f"{len(transcript)}\n" ) # fmt: on for tok, start, end in transcript: if point_tier: tg.write(f'{start:0.{precision}f}\n"{tok}"\n') else: tg.write(f'{start:0.{precision}f}\n{end:0.{precision}f}\n"{tok}"\n')
[docs] def transcript_to_token( transcript: Sequence[Union[str, Tuple[str, float, float]]], token2id: Optional[dict] = None, frame_shift_ms: Optional[float] = None, unk: Optional[Union[str, int]] = None, skip_frame_times: bool = False, ) -> torch.Tensor: r"""Convert a transcript to a token sequence This function converts `transcript` of length ``R`` to a long tensor `tok` of shape ``(R, 3)``, the latter suitable as a reference or hypothesis token sequence for an utterance of :class:`SpectDataSet`. An element of `transcript` can either be a ``token`` or a 3-tuple of ``(token, start, end)``. If `token2id` is not :obj:`None`, the token id is determined by checking ``token2id[token]``. If the token does not exist in `token2id` and `unk` is not :obj:`None`, the token will be replaced with `unk`. If `unk` is :obj:`None`, `token` will be used directly as the id. If `token2id` is not specified, `token` will be used directly as the identifier. If `frame_shift_ms` is specified, ``start`` and ``end`` are taken as the start and end times, in seconds, of the token, and will be converted to frames for `tok`. If `frame_shift_ms` is unspecified, ``start`` and ``end`` are assumed to already be frame times. If ``start`` and ``end`` were unspecified, values of ``-1``, representing unknown, will be inserted into ``tok[r, 1:]`` Parameters ---------- transcript token2id frame_shift_ms unk The out-of-vocabulary token, if specified. If `unk` exists in `token2id`, the ``token2id[unk]`` will be used as the out-of-vocabulary identifier. If ``token2id[unk]`` does not exist, `unk` will be assumed to be the identifier already. If `token2id` is :obj:`None`, `unk` has no effect. skip_frame_times If :obj:`True`, `tok` will be of shape ``(R,)`` and contain only the token ids. Suitable for :class:`BitextDataSet`. Returns ------- tok : torch.Tensor Warnings -------- The frame index bounds inferred using `frame_shift_ms` should not be used directly in evaluation. See the below note. Notes ----- If you are dealing with raw audio, each "frame" is just a sample. The appropriate value for `frame_shift_ms` is ``1000 / sample_rate_hz`` (since there are ``sample_rate_hz / 1000`` samples per millisecond). Converting to frame indices from start and end times follows an overly-simplistic equation. Letting :math:`(s_s, e_s)` be the start and end times in seconds, :math:`(s_f, e_f)` be the corresponding start and end frames, :math:`\Delta` be the frame shift in milliseconds, and :math:`I[\cdot]` be the indicator function. Then .. math:: s_f = floor\left(\frac{1000s_s}{\Delta}\right) \\ e_f = \max\left(s_s + I[s_s = e_s], round\left(\frac{1000e_s}{\Delta}\right)\right) For a given token index, ``tok[r, 1] = s_f`` and ``tok[r, 2] = e_f``. ``tok[r, 1]`` is supposed to be the inclusive start frame of the segment and ``tok[r, 2]`` the exclusive end frame. :math:`(s_f, e_f)` fail to be these on two accounts. First, they do not consider the frame length. First, while frames may be spaced :math:`\Delta` milliseconds apart, they will usually be overlapping. Because of this overlap, the coefficients of frames :math:`s_f - 1` and :math:`e_f` may be in part dependent on the audio samples within the segment. Second, ignoring frame length, :math:`e_f = ceil(1000e_s/\Delta)` would be more appropriate for an exclusive upper bound. However, :mod:`pydrobert.speech.compute` (and other, mainstream feature computation packages), the total number of frames in the utterance is calculated as :math:`T_f = ceil(1000T_s/\Delta)`, where :math:`T_s` is the length of the utterance in seconds. The above equation ensures :math:`\max(e_f) \leq T_f`, which is a neccessary criterion for a valid :class:`SpectDataSet` (see :func:`validate_spec_data_set`). Accounting for both of these assumptions would involve computing the support of each existing frame in seconds and intersecting that with the provided interval in seconds. As such, the derived frame bounds should not be used for an official evaluation. This function should suffice for most training objectives, however. """ if token2id is not None and unk in token2id: unk = token2id[unk] tok_size = (len(transcript),) if not skip_frame_times: tok_size = tok_size + (3,) tok = torch.empty(tok_size, dtype=torch.long) for i, token in enumerate(transcript): start = end = -1 try: if len(token) == 3 and np.isreal(token[1]) and np.isreal(token[2]): token, start, end = token if frame_shift_ms: if start == end: start = end = (1000 * start) // frame_shift_ms else: start = (1000 * start) // frame_shift_ms end = (1000 * end + 0.5 * frame_shift_ms) // frame_shift_ms end = max(end, start + 1) else: start, end = int(start), int(end) except TypeError: pass if token2id is None: id_ = token else: id_ = token2id.get(token, token if unk is None else unk) if skip_frame_times: tok[i] = id_ else: tok[i, 0] = id_ tok[i, 1] = start tok[i, 2] = end return tok
[docs] def token_to_transcript( ref: torch.Tensor, id2token: Optional[Dict[int, str]] = None, frame_shift_ms: Optional[float] = None, ) -> List[Union[str, int, Tuple[Union[str, int], float, float]]]: """Convert a token sequence to a transcript The inverse operation of :func:`transcript_to_token`. Parameters ---------- ref A long tensor either of shape ``(R, 3)`` with segmentation info or ``(R, 1)`` or ``(R,)`` without id2token frame_shift_ms Returns ------- transcript Warnings -------- The time interval inferred using `frame_shift_ms` is unlikely to be perfectly correct. See the note in :func:`transcript_to_token` for more details about the ambiguity in converting between seconds and frames. """ transcript = [] for tup in ref: start = end = -1 if tup.ndim: id_ = tup[0].item() if tup.numel() == 3: start = tup[1].item() end = tup[2].item() else: id_ = tup.item() token = id2token.get(id_, id_) if id2token is not None else id_ if start == -1 or end == -1: transcript.append(token) else: if frame_shift_ms: start = start * frame_shift_ms / 1000 end = end * frame_shift_ms / 1000 transcript.append((token, start, end)) return transcript