File size: 2,082 Bytes
1582553
 
 
 
 
 
 
 
 
f16709f
1582553
 
 
 
 
 
 
 
 
 
 
 
 
 
f16709f
1582553
 
 
 
 
 
 
 
 
 
 
 
 
f16709f
1582553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from pathlib import Path
import numpy as np
import tensorflow as tf
from iSparrow.sparrow_model_base import ModelBase
import pandas as pd


class Model(ModelBase):
    """Model wrapper for the google perch model, built on top of ModelBase."""

    def __init__(self, model_path: str, num_threads: int = 1, **kwargs):
        """
        __init__ Create a new Model instance using the google perch model.

        Args:
            model_path (str): Path to the model folder on disk. The files
                "labels.txt" and "saved_model.pb" are looked up inside it.
            num_threads (int): The number of threads used for inference. Currently not used for this model.
            **kwargs: Further keyword arguments, forwarded unchanged to ModelBase.
        """
        model_dir = Path(model_path)
        labels_path = str(model_dir / "labels.txt")
        model_path = str(model_dir / "saved_model.pb")

        # Set before the base-class initializer runs so the attribute always
        # exists; it is used later.
        self.class_mask = None

        super().__init__(
            "google_perch", model_path, labels_path, num_threads=num_threads, **kwargs
        )  # num_threads doesn't do anything here.

    def predict(self, data: np.ndarray) -> np.ndarray:
        """
        predict Make inference about the bird species for the preprocessed data passed to this function as arguments.

        Args:
            data (np.ndarray): Preprocessed data chunk. It is wrapped into a batch
                of size one before being handed to the model.
        Returns:
            np.ndarray: Softmax of the logits returned by the model. Presumably one
                probability per label along the last axis - confirm against the model.
        """

        # TODO: this should be parallelized??
        # NOTE(review): self.model is not set in this class; presumably it is
        # created by ModelBase.__init__ - confirm.
        batch = np.array([data])
        logits, _ = self.model.infer_tf(batch)

        return tf.nn.softmax(logits).numpy()

    @classmethod
    def from_cfg(cls, sparrow_folder: str, cfg: dict) -> "Model":
        """
        from_cfg Create a new instance from a dictionary containing keyword arguments. Usually loaded from a config file.

        The passed dictionary is not modified.

        Args:
            sparrow_folder (str): Installation directory of the Sparrow package
            cfg (dict): Dictionary containing the keyword arguments. Must contain
                "model_path", given relative to the "models" folder inside sparrow_folder.

        Returns:
            Model: New model instance created with the supplied kwargs.

        Raises:
            KeyError: If cfg has no "model_path" entry.
        """

        # Build a fresh kwargs dict instead of overwriting cfg["model_path"]:
        # mutating the caller's dict would prefix the path a second time when
        # the same config is reused with a relative sparrow_folder.
        kwargs = {
            **cfg,
            "model_path": str(Path(sparrow_folder) / "models" / cfg["model_path"]),
        }

        return cls(**kwargs)