ESPHome  2024.4.2
micro_wake_word.cpp
Go to the documentation of this file.
1 #include "micro_wake_word.h"
2 
8 //
9 #ifndef CLANG_TIDY
10 
11 #ifdef USE_ESP_IDF
12 
13 #include "esphome/core/hal.h"
14 #include "esphome/core/helpers.h"
15 #include "esphome/core/log.h"
16 
18 
19 #include <tensorflow/lite/core/c/common.h>
20 #include <tensorflow/lite/micro/micro_interpreter.h>
21 #include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
22 
23 #include <cmath>
24 
25 namespace esphome {
26 namespace micro_wake_word {
27 
// Log tag passed to every ESP_LOG* call in this file.
28 static const char *const TAG = "micro_wake_word";
29 
// Audio sample rate in samples per second.
30 static const size_t SAMPLE_RATE_HZ = 16000; // 16 kHz
// Ring buffer length in milliseconds.
31 static const size_t BUFFER_LENGTH = 500; // 0.5 seconds
// Ring buffer capacity in samples: 16 samples/ms * 500 ms = 8000.
// It is multiplied by sizeof(int16_t) where the ring buffer is created.
32 static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH;
// Number of samples requested per microphone read: 32 ms of audio = 512 samples.
33 static const size_t INPUT_BUFFER_SIZE = 32 * SAMPLE_RATE_HZ / 1000; // 32ms * 16kHz / 1000ms
34 
36 
// Maps a State value to a LogString for use in log messages.
// Returns LOG_STR("UNKNOWN") for any value that has no dedicated case.
//
// NOTE(review): listing lines 41, 43, 45, 47 and 49 are missing from this listing.
// They presumably hold the `case State::...:` labels matching each string below;
// as shown, every return after the first one would be unreachable. Confirm against
// the real source file before editing.
37 static const LogString *micro_wake_word_state_to_string(State state) {
38  switch (state) {
39  case State::IDLE:
40  return LOG_STR("IDLE");
42  return LOG_STR("START_MICROPHONE");
44  return LOG_STR("STARTING_MICROPHONE");
46  return LOG_STR("DETECTING_WAKE_WORD");
48  return LOG_STR("STOP_MICROPHONE");
50  return LOG_STR("STOPPING_MICROPHONE");
51  default:
52  return LOG_STR("UNKNOWN");
53  }
54 }
55 
// NOTE(review): the function signature (listing line 56) is missing from this listing;
// presumably this is the component's configuration dump - confirm against the header.
// The body logs the configured wake word, the probability cutoff and the sliding
// window size at config log level.
57  ESP_LOGCONFIG(TAG, "microWakeWord:");
58  ESP_LOGCONFIG(TAG, " Wake Word: %s", this->get_wake_word().c_str());
59  ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_);
60  ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_average_size_);
61 }
62 
// NOTE(review): the function signature (listing line 63) is missing from this listing;
// presumably this is the component's setup hook - confirm against the header.
//
// Initializes the models, then allocates the microphone input buffer and the ring
// buffer. Any failure marks the component as failed and returns early.
64  ESP_LOGCONFIG(TAG, "Setting up microWakeWord...");
65 
66  if (!this->initialize_models()) {
67  ESP_LOGE(TAG, "Failed to initialize models");
68  this->mark_failed();
69  return;
70  }
71 
// NOTE(review): the declaration of `allocator` (listing line 72) is missing from this
// listing. If it counts elements rather than bytes, the sizeof(int16_t) factor below
// requests twice the space needed - confirm against the allocator's declaration.
73  this->input_buffer_ = allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t));
74  if (this->input_buffer_ == nullptr) {
75  ESP_LOGW(TAG, "Could not allocate input buffer");
76  this->mark_failed();
77  return;
78  }
79 
// The ring buffer is sized in bytes: BUFFER_SIZE samples of int16_t each.
80  this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t));
81  if (this->ring_buffer_ == nullptr) {
82  ESP_LOGW(TAG, "Could not allocate ring buffer");
83  this->mark_failed();
84  return;
85  }
86 
87  ESP_LOGCONFIG(TAG, "Micro Wake Word initialized");
88 }
89 
// NOTE(review): the function signature (listing line 90) is missing from this listing;
// this appears to be the body of read_microphone_(), which the state machine calls.
//
// Reads up to INPUT_BUFFER_SIZE samples from the microphone into input_buffer_ and
// forwards them to the ring buffer. Returns 0 when the microphone produced no data,
// otherwise the value returned by the ring buffer's write().
91  size_t bytes_read = this->microphone_->read(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t));
92  if (bytes_read == 0) {
93  return 0;
94  }
95 
96  size_t bytes_free = this->ring_buffer_->free();
97 
// When the consumer has fallen behind, the buffered audio is dropped rather than the
// new audio: the ring buffer is reset and the fresh data is then written below.
// NOTE(review): both arguments are size_t but are formatted with %d - confirm this
// is acceptable on every target this builds for.
98  if (bytes_free < bytes_read) {
99  ESP_LOGW(TAG,
100  "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). "
101  "Resetting the ring buffer. Wake word detection accuracy will be reduced.",
102  bytes_free, bytes_read);
103 
104  this->ring_buffer_->reset();
105  }
106 
107  return this->ring_buffer_->write((void *) this->input_buffer_, bytes_read);
108 }
109 
// NOTE(review): the function signature (listing line 110) is missing from this listing;
// presumably this is the component's periodic loop hook - confirm against the header.
//
// Runs one step of the microphone / detection state machine based on this->state_.
//
// NOTE(review): listing lines 114, 117, 120, 122, 125, 130, 133, 136, 139 and 144 are
// also missing. They presumably hold the remaining `case State::...:` labels, the
// set_state_() transitions and the detection trigger call. Do not treat the visible
// text as the complete logic.
111  switch (this->state_) {
112  case State::IDLE:
113  break;
// Branch that starts the microphone and requests high-frequency looping.
115  ESP_LOGD(TAG, "Starting Microphone");
116  this->microphone_->start();
118  this->high_freq_.start();
119  break;
// Branch that waits until the microphone reports it is running (its body is missing here).
121  if (this->microphone_->is_running()) {
123  }
124  break;
// Branch that pulls fresh audio into the ring buffer and runs detection on it.
126  this->read_microphone_();
127  if (this->detect_wake_word_()) {
128  ESP_LOGD(TAG, "Wake Word Detected");
129  this->detected_ = true;
131  }
132  break;
// Branch that stops the microphone and releases the high-frequency loop request.
134  ESP_LOGD(TAG, "Stopping Microphone");
135  this->microphone_->stop();
137  this->high_freq_.stop();
138  break;
// Branch that returns to IDLE once the microphone has stopped and clears the
// detected_ flag if it was set.
140  if (this->microphone_->is_stopped()) {
141  this->set_state_(State::IDLE);
142  if (this->detected_) {
143  this->detected_ = false;
145  }
146  }
147  break;
148  }
149 }
150 
// NOTE(review): the function signature (listing line 151) and the final statement
// (listing line 160) are missing from this listing; presumably this is the public
// start() entry point and the missing statement moves the state machine out of IDLE.
//
// Visible behaviour: refuses to start, with a warning, when the component is marked
// as failed or when the state machine is not in State::IDLE.
152  if (this->is_failed()) {
153  ESP_LOGW(TAG, "Wake word component is marked as failed. Please check setup logs");
154  return;
155  }
156  if (this->state_ != State::IDLE) {
157  ESP_LOGW(TAG, "Wake word is already running");
158  return;
159  }
161 }
162 
// NOTE(review): the function signature (listing line 163) and the final statement
// (listing line 172) are missing from this listing; presumably this is the public
// stop() entry point and the missing statement requests the microphone to stop.
//
// Visible behaviour: warns and returns without acting when the state machine is
// already in State::IDLE or State::STOPPING_MICROPHONE.
164  if (this->state_ == State::IDLE) {
165  ESP_LOGW(TAG, "Wake word is already stopped");
166  return;
167  }
168  if (this->state_ == State::STOPPING_MICROPHONE) {
169  ESP_LOGW(TAG, "Wake word is already stopping");
170  return;
171  }
173 }
174 
// NOTE(review): the function signature (listing line 175) is missing from this listing;
// this appears to be the body of set_state_(State state).
//
// Logs the transition from the current state to `state` at debug level, then stores
// `state` in this->state_. The log is emitted before the assignment so that the old
// value is still available for the message.
176  ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
177  LOG_STR_ARG(micro_wake_word_state_to_string(state)));
178  this->state_ = state;
179 }
180 
185 
// NOTE(review): the function signature and the declarations of `arena_allocator`,
// `features_allocator` and `audio_samples_allocator` (listing lines 181-184) are
// missing from this listing; this appears to be the body of initialize_models(),
// which returns bool.
//
// Allocates the tensor arenas and working buffers, loads both TFLite models,
// registers their ops, builds the two interpreters, allocates their tensors and
// validates the streaming model's input/output tensors. Returns false on the first
// failure (with the exceptions flagged below) and true otherwise.
//
// NOTE(review): buffers allocated before a later failure are not released in the
// visible code - confirm whether cleanup happens elsewhere.
186  this->streaming_tensor_arena_ = arena_allocator.allocate(STREAMING_MODEL_ARENA_SIZE);
187  if (this->streaming_tensor_arena_ == nullptr) {
188  ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena.");
189  return false;
190  }
191 
192  this->streaming_var_arena_ = arena_allocator.allocate(STREAMING_MODEL_VARIABLE_ARENA_SIZE);
193  if (this->streaming_var_arena_ == nullptr) {
194  ESP_LOGE(TAG, "Could not allocate the streaming model variable's tensor arena.");
195  return false;
196  }
197 
198  this->preprocessor_tensor_arena_ = arena_allocator.allocate(PREPROCESSOR_ARENA_SIZE);
199  if (this->preprocessor_tensor_arena_ == nullptr) {
200  ESP_LOGE(TAG, "Could not allocate the audio preprocessor model's tensor arena.");
201  return false;
202  }
203 
204  this->new_features_data_ = features_allocator.allocate(PREPROCESSOR_FEATURE_SIZE);
205  if (this->new_features_data_ == nullptr) {
206  ESP_LOGE(TAG, "Could not allocate the audio features buffer.");
207  return false;
208  }
209 
210  this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(SAMPLE_DURATION_COUNT);
211  if (this->preprocessor_audio_buffer_ == nullptr) {
212  ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer.");
213  return false;
214  }
215 
// Both models must have been built against the schema version this runtime was
// compiled with.
216  this->preprocessor_model_ = tflite::GetModel(G_AUDIO_PREPROCESSOR_INT8_TFLITE);
217  if (this->preprocessor_model_->version() != TFLITE_SCHEMA_VERSION) {
218  ESP_LOGE(TAG, "Wake word's audio preprocessor model's schema is not supported");
219  return false;
220  }
221 
222  this->streaming_model_ = tflite::GetModel(this->model_start_);
223  if (this->streaming_model_->version() != TFLITE_SCHEMA_VERSION) {
224  ESP_LOGE(TAG, "Wake word's streaming model's schema is not supported");
225  return false;
226  }
227 
// The resolver capacities (18 and 17) must match the number of Add* calls made in
// register_preprocessor_ops_() and register_streaming_ops_() respectively.
// The resolvers are function-local statics, so they outlive this call.
228  static tflite::MicroMutableOpResolver<18> preprocessor_op_resolver;
229  static tflite::MicroMutableOpResolver<17> streaming_op_resolver;
230 
231  if (!this->register_preprocessor_ops_(preprocessor_op_resolver))
232  return false;
233  if (!this->register_streaming_ops_(streaming_op_resolver))
234  return false;
235 
// NOTE(review): neither `ma` nor this->mrv_ is checked for nullptr before use - confirm
// whether Create() can fail here.
236  tflite::MicroAllocator *ma =
237  tflite::MicroAllocator::Create(this->streaming_var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
238  this->mrv_ = tflite::MicroResourceVariables::Create(ma, 15);
239 
// The interpreters are function-local statics: they are constructed the first time
// execution reaches these lines and keep the arguments they were given then.
240  static tflite::MicroInterpreter static_preprocessor_interpreter(
241  this->preprocessor_model_, preprocessor_op_resolver, this->preprocessor_tensor_arena_, PREPROCESSOR_ARENA_SIZE);
242 
// NOTE(review): listing line 244 (one line of this constructor's argument list) is
// missing from this listing.
243  static tflite::MicroInterpreter static_streaming_interpreter(this->streaming_model_, streaming_op_resolver,
245  STREAMING_MODEL_ARENA_SIZE, this->mrv_);
246 
247  this->preprocessor_interperter_ = &static_preprocessor_interpreter;
248  this->streaming_interpreter_ = &static_streaming_interpreter;
249 
250  // Allocate tensors for each models.
251  if (this->preprocessor_interperter_->AllocateTensors() != kTfLiteOk) {
252  ESP_LOGE(TAG, "Failed to allocate tensors for the audio preprocessor");
253  return false;
254  }
255  if (this->streaming_interpreter_->AllocateTensors() != kTfLiteOk) {
256  ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
257  return false;
258  }
259 
260  // Verify input tensor matches expected values
// Expected input shape is 1 x 1 x PREPROCESSOR_FEATURE_SIZE.
// NOTE(review): the `dims->data[0] != 1` test appears twice in this condition; the
// second copy is redundant.
// NOTE(review): the error log reads input->dims->data[2] even when dims->size != 3,
// which can read past the end of the dims array. It also prints the model's actual
// third dimension where the message implies the expected size.
261  TfLiteTensor *input = this->streaming_interpreter_->input(0);
262  if ((input->dims->size != 3) || (input->dims->data[0] != 1) || (input->dims->data[0] != 1) ||
263  (input->dims->data[1] != 1) || (input->dims->data[2] != PREPROCESSOR_FEATURE_SIZE)) {
264  ESP_LOGE(TAG, "Wake word detection model tensor input dimensions is not 1x1x%u", input->dims->data[2]);
265  return false;
266  }
267 
268  if (input->type != kTfLiteInt8) {
269  ESP_LOGE(TAG, "Wake word detection model tensor input is not int8.");
270  return false;
271  }
272 
273  // Verify output tensor matches expected values
// NOTE(review): unlike every other validation in this function, a wrong output shape
// is only logged - there is no `return false`, so initialization continues with an
// unexpected output tensor. This looks like an omission.
274  TfLiteTensor *output = this->streaming_interpreter_->output(0);
275  if ((output->dims->size != 2) || (output->dims->data[0] != 1) || (output->dims->data[1] != 1)) {
276  ESP_LOGE(TAG, "Wake word detection model tensor output dimensions is not 1x1.");
277  }
278 
// NOTE(review): this checks the OUTPUT tensor type, but the message says "input".
279  if (output->type != kTfLiteUInt8) {
280  ESP_LOGE(TAG, "Wake word detection model tensor input is not uint8.");
281  return false;
282  }
283 
// NOTE(review): listing line 284 is missing from this listing.
285 
286  return true;
287 }
288 
// NOTE(review): the function signature (listing line 289) is missing from this listing;
// this appears to be the body of update_features_(), which returns bool.
//
// Fetches the next window of audio samples and runs the preprocessor on it, storing
// the result in this->new_features_data_. Returns false when no full window is
// available or when feature generation fails, true otherwise.
290  // Retrieve strided audio samples
291  int16_t *audio_samples = nullptr;
292  if (!this->stride_audio_samples_(&audio_samples)) {
293  return false;
294  }
295 
296  // Compute the features for the newest audio samples
297  if (!this->generate_single_feature_(audio_samples, SAMPLE_DURATION_COUNT, this->new_features_data_)) {
298  return false;
299  }
300 
301  return true;
302 }
303 
// NOTE(review): the function signature (listing line 304) is missing from this listing;
// this appears to be the body of perform_streaming_inference_(), which is listed as
// returning float.
//
// Copies the newest feature slice into the streaming model's input tensor, invokes
// the model and returns the first output byte scaled to the range [0, 1].
305  TfLiteTensor *input = this->streaming_interpreter_->input(0);
306 
// The copy length comes from the tensor itself.
// NOTE(review): this assumes new_features_data_ holds at least input->bytes bytes -
// confirm against its allocation size.
307  size_t bytes_to_copy = input->bytes;
308 
309  memcpy((void *) (tflite::GetTensorData<int8_t>(input)), (const void *) (this->new_features_data_), bytes_to_copy);
310 
311  uint32_t prior_invoke = millis();
312 
// NOTE(review): `return false` in a float-returning function converts to 0.0f, so a
// failed invoke is indistinguishable from a probability of zero for the caller.
313  TfLiteStatus invoke_status = this->streaming_interpreter_->Invoke();
314  if (invoke_status != kTfLiteOk) {
315  ESP_LOGW(TAG, "Streaming Interpreter Invoke failed");
316  return false;
317  }
318 
319  ESP_LOGV(TAG, "Streaming Inference Latency=%u ms", (millis() - prior_invoke));
320 
321  TfLiteTensor *output = this->streaming_interpreter_->output(0);
322 
// The output is an unsigned 8-bit value; dividing by 255 maps it onto [0, 1].
323  return static_cast<float>(output->data.uint8[0]) / 255.0;
324 }
325 
// NOTE(review): the function signature (listing line 326) is missing from this listing;
// this appears to be the body of detect_wake_word_(), which returns bool.
//
// Runs one detection step: generates features for the newest audio, runs the
// streaming model, updates the sliding window of recent probabilities and returns
// true when the window average exceeds probability_cutoff_. Returns false when no
// new features could be produced, while detections are being suppressed after a
// previous hit, or when the average is at or below the cutoff.
327  // Preprocess the newest audio samples into features
328  if (!this->update_features_()) {
329  return false;
330  }
331 
332  // Perform inference
333  float streaming_prob = this->perform_streaming_inference_();
334 
335  // Add the most recent probability to the sliding window
// last_n_index_ wraps back to 0 when it reaches sliding_window_average_size_, so the
// vector is used as a circular buffer.
336  this->recent_streaming_probabilities_[this->last_n_index_] = streaming_prob;
337  ++this->last_n_index_;
338  if (this->last_n_index_ == this->sliding_window_average_size_)
339  this->last_n_index_ = 0;
340 
// NOTE(review): the sum covers every element of the vector while the divisor is
// sliding_window_average_size_; this is only an average if the vector's size equals
// that value - confirm where the vector is resized.
341  float sum = 0.0;
342  for (auto &prob : this->recent_streaming_probabilities_) {
343  sum += prob;
344  }
345 
346  float sliding_window_average = sum / static_cast<float>(this->sliding_window_average_size_);
347 
348  // Ensure we have enough samples since the last positive detection
// ignore_windows_ is set negative after a detection and counts up by one per call,
// capped at 0; detections are suppressed while it is still negative.
349  this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
350  if (this->ignore_windows_ < 0) {
351  return false;
352  }
353 
354  // Detect the wake word if the sliding window average is above the cutoff
// On detection the suppression counter is re-armed and the window is cleared so the
// same utterance cannot trigger again immediately.
355  if (sliding_window_average > this->probability_cutoff_) {
356  this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
357  for (auto &prob : this->recent_streaming_probabilities_) {
358  prob = 0;
359  }
360 
361  ESP_LOGD(TAG, "Wake word sliding average probability is %.3f and most recent probability is %.3f",
362  sliding_window_average, streaming_prob);
363  return true;
364  }
365 
366  return false;
367 }
368 
// NOTE(review): the function signature (listing line 369) and the statement at listing
// line 371 are missing from this listing; presumably this is a setter taking `size`,
// and the missing statement resizes the probability window - confirm against the
// real source.
// Visible behaviour: stores `size` in this->sliding_window_average_size_.
370  this->sliding_window_average_size_ = size;
372 }
373 
// NOTE(review): the function signature (listing line 374) is missing from this listing;
// this appears to be the body of slice_available_(), which returns bool.
//
// Reports whether the ring buffer holds enough audio for another slice.
375  size_t available = this->ring_buffer_->available();
376 
// NOTE(review): the comparison is strict, so exactly NEW_SAMPLES_TO_GET samples in
// the buffer is reported as "not available"; `>=` may have been intended - confirm.
377  return available > (NEW_SAMPLES_TO_GET * sizeof(int16_t));
378 }
379 
380 bool MicroWakeWord::stride_audio_samples_(int16_t **audio_samples) {
381  if (!this->slice_available_()) {
382  return false;
383  }
384 
385  // Copy the last 320 bytes (160 samples over 10 ms) from the audio buffer to the start of the audio buffer
386  memcpy((void *) (this->preprocessor_audio_buffer_), (void *) (this->preprocessor_audio_buffer_ + NEW_SAMPLES_TO_GET),
387  HISTORY_SAMPLES_TO_KEEP * sizeof(int16_t));
388 
389  // Copy 640 bytes (320 samples over 20 ms) from the ring buffer into the audio buffer offset 320 bytes (160 samples
390  // over 10 ms)
391  size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_ + HISTORY_SAMPLES_TO_KEEP),
392  NEW_SAMPLES_TO_GET * sizeof(int16_t), pdMS_TO_TICKS(200));
393 
394  if (bytes_read == 0) {
395  ESP_LOGE(TAG, "Could not read data from Ring Buffer");
396  } else if (bytes_read < NEW_SAMPLES_TO_GET * sizeof(int16_t)) {
397  ESP_LOGD(TAG, "Partial Read of Data by Model");
398  ESP_LOGD(TAG, "Could only read %d bytes when required %d bytes ", bytes_read,
399  (int) (NEW_SAMPLES_TO_GET * sizeof(int16_t)));
400  return false;
401  }
402 
403  *audio_samples = this->preprocessor_audio_buffer_;
404  return true;
405 }
406 
407 bool MicroWakeWord::generate_single_feature_(const int16_t *audio_data, const int audio_data_size,
408  int8_t feature_output[PREPROCESSOR_FEATURE_SIZE]) {
409  TfLiteTensor *input = this->preprocessor_interperter_->input(0);
410  TfLiteTensor *output = this->preprocessor_interperter_->output(0);
411  std::copy_n(audio_data, audio_data_size, tflite::GetTensorData<int16_t>(input));
412 
413  if (this->preprocessor_interperter_->Invoke() != kTfLiteOk) {
414  ESP_LOGE(TAG, "Failed to preprocess audio for local wake word.");
415  return false;
416  }
417  std::memcpy(feature_output, tflite::GetTensorData<int8_t>(output), PREPROCESSOR_FEATURE_SIZE * sizeof(int8_t));
418 
419  return true;
420 }
421 
422 bool MicroWakeWord::register_preprocessor_ops_(tflite::MicroMutableOpResolver<18> &op_resolver) {
423  if (op_resolver.AddReshape() != kTfLiteOk)
424  return false;
425  if (op_resolver.AddCast() != kTfLiteOk)
426  return false;
427  if (op_resolver.AddStridedSlice() != kTfLiteOk)
428  return false;
429  if (op_resolver.AddConcatenation() != kTfLiteOk)
430  return false;
431  if (op_resolver.AddMul() != kTfLiteOk)
432  return false;
433  if (op_resolver.AddAdd() != kTfLiteOk)
434  return false;
435  if (op_resolver.AddDiv() != kTfLiteOk)
436  return false;
437  if (op_resolver.AddMinimum() != kTfLiteOk)
438  return false;
439  if (op_resolver.AddMaximum() != kTfLiteOk)
440  return false;
441  if (op_resolver.AddWindow() != kTfLiteOk)
442  return false;
443  if (op_resolver.AddFftAutoScale() != kTfLiteOk)
444  return false;
445  if (op_resolver.AddRfft() != kTfLiteOk)
446  return false;
447  if (op_resolver.AddEnergy() != kTfLiteOk)
448  return false;
449  if (op_resolver.AddFilterBank() != kTfLiteOk)
450  return false;
451  if (op_resolver.AddFilterBankSquareRoot() != kTfLiteOk)
452  return false;
453  if (op_resolver.AddFilterBankSpectralSubtraction() != kTfLiteOk)
454  return false;
455  if (op_resolver.AddPCAN() != kTfLiteOk)
456  return false;
457  if (op_resolver.AddFilterBankLog() != kTfLiteOk)
458  return false;
459 
460  return true;
461 }
462 
463 bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<17> &op_resolver) {
464  if (op_resolver.AddCallOnce() != kTfLiteOk)
465  return false;
466  if (op_resolver.AddVarHandle() != kTfLiteOk)
467  return false;
468  if (op_resolver.AddReshape() != kTfLiteOk)
469  return false;
470  if (op_resolver.AddReadVariable() != kTfLiteOk)
471  return false;
472  if (op_resolver.AddStridedSlice() != kTfLiteOk)
473  return false;
474  if (op_resolver.AddConcatenation() != kTfLiteOk)
475  return false;
476  if (op_resolver.AddAssignVariable() != kTfLiteOk)
477  return false;
478  if (op_resolver.AddConv2D() != kTfLiteOk)
479  return false;
480  if (op_resolver.AddMul() != kTfLiteOk)
481  return false;
482  if (op_resolver.AddAdd() != kTfLiteOk)
483  return false;
484  if (op_resolver.AddMean() != kTfLiteOk)
485  return false;
486  if (op_resolver.AddFullyConnected() != kTfLiteOk)
487  return false;
488  if (op_resolver.AddLogistic() != kTfLiteOk)
489  return false;
490  if (op_resolver.AddQuantize() != kTfLiteOk)
491  return false;
492  if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
493  return false;
494  if (op_resolver.AddAveragePool2D() != kTfLiteOk)
495  return false;
496  if (op_resolver.AddMaxPool2D() != kTfLiteOk)
497  return false;
498 
499  return true;
500 }
501 
502 } // namespace micro_wake_word
503 } // namespace esphome
504 
505 #endif // USE_ESP_IDF
506 
507 #endif // CLANG_TIDY
tflite::MicroInterpreter * streaming_interpreter_
const float AFTER_CONNECTION
For components that should be initialized after a data connection (API/MQTT) is connected.
Definition: component.cpp:27
bool detect_wake_word_()
Detects if wake word has been said.
bool register_preprocessor_ops_(tflite::MicroMutableOpResolver< 18 > &op_resolver)
Returns true if successfully registered the preprocessor's TensorFlow operations. ...
Trigger< std::string > * wake_word_detected_trigger_
std::unique_ptr< RingBuffer > ring_buffer_
std::vector< float > recent_streaming_probabilities_
An STL allocator that uses SPI RAM.
Definition: helpers.h:645
HighFrequencyLoopRequester high_freq_
bool stride_audio_samples_(int16_t **audio_samples)
Strides the audio samples by keeping the last 10 ms of the previous slice.
uint32_t IRAM_ATTR HOT millis()
Definition: core.cpp:25
void trigger(Ts... x)
Inform the parent automation that the event has triggered.
Definition: automation.h:95
tflite::MicroInterpreter * preprocessor_interperter_
bool update_features_()
Shifts previous feature slices over by one and generates a new slice of features. ...
float perform_streaming_inference_()
Performs inference over the most recent feature slice with the streaming model.
microphone::Microphone * microphone_
void start()
Start running the loop continuously.
Definition: helpers.cpp:547
bool register_streaming_ops_(tflite::MicroMutableOpResolver< 17 > &op_resolver)
Returns true if successfully registered the streaming model's TensorFlow operations.
void stop()
Stop running the loop continuously.
Definition: helpers.cpp:553
virtual size_t read(int16_t *buf, size_t len)=0
virtual void mark_failed()
Mark this component as failed.
Definition: component.cpp:118
bool generate_single_feature_(const int16_t *audio_data, int audio_data_size, int8_t feature_output[PREPROCESSOR_FEATURE_SIZE])
Generates features from audio samples.
This is a workaround until we can figure out a way to get the tflite-micro idf component code availab...
Definition: a01nyub.cpp:7
bool slice_available_()
Returns true if there are enough audio samples in the buffer to generate another slice of features...
static std::unique_ptr< RingBuffer > create(size_t len)
Definition: ring_buffer.cpp:14
const unsigned char G_AUDIO_PREPROCESSOR_INT8_TFLITE[]
bool state
Definition: fan.h:34
tflite::MicroResourceVariables * mrv_