Skip to content
68 changes: 34 additions & 34 deletions src/htm/algorithms/SDRClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,11 @@ void Classifier::initialize(const Real alpha)
}


PDF Classifier::infer(const SDR & pattern)
{
// Check input dimensions, or if this is the first time the Classifier has
// been used then initialize it with the given SDR's dimensions.
PDF Classifier::infer(const SDR & pattern) const {
// Check input dimensions, or if this is the first time the Classifier is used and dimensions
// are unset, return zeroes.
if( dimensions_.empty() ) {
dimensions_ = pattern.dimensions;
while( weights_.size() < pattern.size ) {
weights_.push_back( vector<Real>( numCategories_, 0.0f ));
}
return PDF(numCategories_, 0.0f); //empty
} else if( pattern.dimensions != dimensions_ ) {
stringstream err_msg;
err_msg << "Classifier input SDR.dimensions mismatch: previously given SDR with dimensions ( ";
Expand Down Expand Up @@ -81,6 +77,15 @@ PDF Classifier::infer(const SDR & pattern)

void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
{
// If this is the first time the Classifier is being used, weights are empty,
// so we set the dimensions to that of the input `pattern`
if( dimensions_.empty() ) {
dimensions_ = pattern.dimensions;
while( weights_.size() < pattern.size ) {
const auto initialEmptyWeights = PDF( numCategories_, 0.0f );
weights_.push_back( initialEmptyWeights );
}
}
// Check if this is a new category & resize the weights table to hold it.
const auto maxCategoryIdx = *max_element(categoryIdxList.begin(), categoryIdxList.end());
if( maxCategoryIdx >= numCategories_ ) {
Expand All @@ -93,7 +98,7 @@ void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
}

// Compute errors and update weights.
const vector<Real> error = calculateError_(categoryIdxList, pattern);
const auto& error = calculateError_(categoryIdxList, pattern);
for( const auto& bit : pattern.getSparse() ) {
for(size_t i = 0u; i < numCategories_; i++) {
weights_[bit][i] += alpha_ * error[i];
Expand All @@ -103,9 +108,8 @@ void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)


// Helper function to compute the error signal in learning.
std::vector<Real> Classifier::calculateError_(
const std::vector<UInt> &categoryIdxList, const SDR &pattern)
{
std::vector<Real64> Classifier::calculateError_(const std::vector<UInt> &categoryIdxList,
const SDR &pattern) const {
// compute predicted likelihoods
auto likelihoods = infer(pattern);

Expand Down Expand Up @@ -165,13 +169,12 @@ void Predictor::reset() {
}


Predictions Predictor::infer(const UInt recordNum, const SDR &pattern)
{
updateHistory_( recordNum, pattern );
Predictions Predictor::infer(const UInt recordNum, const SDR &pattern) const {
checkMonotonic_(recordNum);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ctrl-z-9000-times please have a look at the last 1 (or 2) commits. I intended to make infer const, as IMHO it should have been. For Classifier that was easy; for Predictor I had to remove the updateHistory_ call from infer (no tests visibly broken).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it matter if infer is constant? I doubt it will have any performance impact, and I don't think this will prevent any programming mistakes.

Consider the following chain of events:

infer( t, SDR-A ) -> PDF
learn( t+1, SDR-B, Labels )

The updateHistory method stores the given SDR inside the predictor. Previously the call to learn would have associated SDR-A with Labels, since that SDR had been given to infer. Now that will not happen.

Also, these changes allow the timestamps given to the infer method to go backwards.

infer( t + 2, ... )
infer( t, ... )

I'm not saying the new behavior is wrong, just that it changed.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt it will have any performance impact, and I don't think this will prevent any programming mistakes.

right, there won't be any performance gains, and the behavior has changed. I think it adds (expected) semantics, and better separates responsibilities of those 2 functions:

  • you know only what you learn()
  • and you can infer() const anytime without any worry of changing the state (comes with loosened requirement for monotonic timestamps)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that makes sense +1


Predictions result;
for( const auto step : steps_ ) {
result[step] = classifiers_[step].infer( pattern );
result[step] = classifiers_.at(step).infer( pattern );
}
return result;
}
Expand All @@ -180,7 +183,18 @@ Predictions Predictor::infer(const UInt recordNum, const SDR &pattern)
void Predictor::learn(const UInt recordNum, const SDR &pattern,
const std::vector<UInt> &bucketIdxList)
{
updateHistory_( recordNum, pattern );
checkMonotonic_(recordNum);

// Update pattern history if this is a new record.
const UInt lastRecordNum = recordNumHistory_.empty() ? -1 : recordNumHistory_.back();
if (recordNumHistory_.size() == 0u || recordNum > lastRecordNum) {
patternHistory_.emplace_back( pattern );
recordNumHistory_.push_back(recordNum);
if (patternHistory_.size() > steps_.back() + 1u) {
patternHistory_.pop_front();
recordNumHistory_.pop_front();
}
}

// Iterate through all recently given inputs, starting from the furthest in the past.
auto pastPattern = patternHistory_.begin();
Expand All @@ -197,24 +211,10 @@ void Predictor::learn(const UInt recordNum, const SDR &pattern,
}


void Predictor::updateHistory_(const UInt recordNum, const SDR & pattern)
{
void Predictor::checkMonotonic_(const UInt recordNum) const {
// Ensure that recordNum increases monotonically.
UInt lastRecordNum = -1;
const UInt lastRecordNum = recordNumHistory_.empty() ? -1 : recordNumHistory_.back();
if( not recordNumHistory_.empty() ) {
lastRecordNum = recordNumHistory_.back();
if (recordNum < lastRecordNum) {
NTA_THROW << "The record number must increase monotonically.";
}
NTA_CHECK(recordNum >= lastRecordNum) << "The record number must increase monotonically.";
}

// Update pattern history if this is a new record.
if (recordNumHistory_.size() == 0u || recordNum > lastRecordNum) {
patternHistory_.emplace_back( pattern );
recordNumHistory_.push_back(recordNum);
if (patternHistory_.size() > steps_.back() + 1u) {
patternHistory_.pop_front();
recordNumHistory_.pop_front();
}
}
}
30 changes: 21 additions & 9 deletions src/htm/algorithms/SDRClassifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@

/** @file
* Definitions for the SDR Classifier & Predictor.
*
 * `Classifier` learns a mapping from SDR -> input value (the encoder's output). This is used when you need to "explain" the HTM network's output in real-world terms, i.e. mapping SDRs
 * back to digits in an MNIST digit-classification task.
*
 * `Predictor` provides similar functionality for time-sequences, where you want to "predict" N steps ahead and then return the real-world value.
 * Internally it uses (several) Classifiers. In nupic.core this used to be part of SDRClassifier; in htm.core it is a separate class, Predictor.
*
*/

#ifndef NTA_SDR_CLASSIFIER_HPP
Expand All @@ -43,7 +50,7 @@ namespace htm {
*
* See also: https://en.wikipedia.org/wiki/Probability_distribution
*/
using PDF = std::vector<Real>;
using PDF = std::vector<Real64>; //Real64 (not Real/float) must be used here, otherwise precision is lost and Predictor never reaches sufficient results.

/**
 * Returns the category with the greatest probability.
Expand Down Expand Up @@ -115,7 +122,7 @@ class Classifier : public Serializable
* @returns: The Probablility Distribution Function (PDF) of the categories.
* This is indexed by the category label.
*/
PDF infer(const SDR & pattern);
PDF infer(const SDR & pattern) const;

/**
* Learn from example data.
Expand Down Expand Up @@ -147,12 +154,13 @@ class Classifier : public Serializable
/**
* 2D map used to store the data.
* Use as: weights_[ input-bit ][ category-index ]
* Real64 (not just Real) so the computations do not lose precision.
*/
std::vector<std::vector<Real>> weights_;
std::vector<std::vector<Real64>> weights_;

// Helper function to compute the error signal for learning.
std::vector<Real> calculateError_(const std::vector<UInt> &bucketIdxList,
const SDR &pattern);
std::vector<Real64> calculateError_(const std::vector<UInt> &bucketIdxList,
const SDR &pattern) const;
};

/**
Expand All @@ -179,9 +187,11 @@ using Predictions = std::map<UInt, PDF>;
* This class handles missing datapoints.
*
* Compatibility Note: This class is the replacement for the old SDRClassifier.
* It no longer provides estimates of the actual value.
* It no longer provides estimates of the actual value. Instead, users can get a rough estimate
* from bucket-index. If more precision is needed, use more buckets in the encoder.
*
* Example Usage:
* ```
* // Predict 1 and 2 time steps into the future.
* // Make a sequence of 4 random SDRs. Each SDR has 1000 bits and 2% sparsity.
* vector<SDR> sequence( 4, { 1000 } );
Expand All @@ -208,6 +218,7 @@ using Predictions = std::map<UInt, PDF>;
* Predictions B = pred.infer( 1, sequence[1] );
* argmax( B[1] ) -> labels[2]
* argmax( B[2] ) -> labels[3]
* ```
*/
class Predictor : public Serializable
{
Expand Down Expand Up @@ -242,7 +253,7 @@ class Predictor : public Serializable
*
* @returns: A mapping from prediction step to PDF.
*/
Predictions infer(UInt recordNum, const SDR &pattern);
Predictions infer(const UInt recordNum, const SDR &pattern) const;

/**
* Learn from example data.
Expand All @@ -252,7 +263,8 @@ class Predictor : public Serializable
* @param pattern: The active input SDR.
* @param bucketIdxList: Vector of the current value bucket indices or categories.
*/
void learn(UInt recordNum, const SDR &pattern,
void learn(const UInt recordNum,
const SDR &pattern,
const std::vector<UInt> &bucketIdxList);

CerealAdapter;
Expand All @@ -276,7 +288,7 @@ class Predictor : public Serializable
// Stores the input pattern history, starting with the previous input.
std::deque<SDR> patternHistory_;
std::deque<UInt> recordNumHistory_;
void updateHistory_(UInt recordNum, const SDR & pattern);
void checkMonotonic_(UInt recordNum) const;

// One per prediction step
std::map<UInt, Classifier> classifiers_;
Expand Down