#ifndef __TRWS_H__ #define __TRWS_H__ #include #include #include #include #include "mrf.h" class TRWS : public MRF{ public: typedef double REAL; TRWS(int width, int height, int nLabels, EnergyFunction *eng); TRWS(int nPixels, int nLabels,EnergyFunction *eng); ~TRWS(); void setNeighbors(int /*pix1*/, int /*pix2*/, CostVal /*weight*/){printf("Not implemented"); exit(1);} Label getLabel(int pixel){return(m_answer[pixel]);}; void setLabel(int pixel,Label label){m_answer[pixel] = label;}; Label* getAnswerPtr(){return(m_answer);}; void clearAnswer(); void setParameters(int /*numParam*/, void * /*param*/){printf("No optional parameters to set"); exit(1);} EnergyVal smoothnessEnergy(); EnergyVal dataEnergy(); double lowerBound() { return (double)m_lowerBound; } // For general smoothness functions, this code tries to cache all function values in an array // for efficiency. To prevent this, call the following function before calling initialize(): void dontCacheSmoothnessCosts() {m_allocateArrayForSmoothnessCostFn = false;} protected: void setData(DataCostFn dcost); void setData(CostVal* data); void setSmoothness(SmoothCostGeneralFn cost); void setSmoothness(CostVal* V); void setSmoothness(int smoothExp,CostVal smoothMax, CostVal lambda); void setCues(CostVal* hCue, CostVal* vCue); void Allocate(); void initializeAlg(); void optimizeAlg(int nIterations); private: enum { NONE, L1, L2, FIXED_MATRIX, GENERAL, BINARY, } m_type; CostVal m_smoothMax; // used only if CostVal m_lambda; // m_type == L1 or m_type == L2 Label *m_answer; CostVal *m_V; // points to array of size nLabels^2 (if type==FIXED_MATRIX) or of size nEdges*nLabels^2 (if type==GENERAL) CostVal *m_D; CostVal *m_DBinary; // valid if type == BINARY CostVal *m_horzWeights; CostVal *m_vertWeights; CostVal *m_horzWeightsBinary; CostVal *m_vertWeightsBinary; DataCostFn m_dataFn; SmoothCostGeneralFn m_smoothFn; bool m_needToFreeV; bool m_needToFreeD; REAL* m_messages; // size of one message: N = 1 if m_type == BINARY, N = K otherwise // message between edges (x,y)-(x+1,y): m_messages+(2*x+2*y*m_width)*N // message between edges (x,y)-(x,y+1): m_messages+(2*x+2*y*m_width+1)*N int m_messageArraySizeInBytes; REAL m_lowerBound; void optimize_GRID_L1(int nIterations); void optimize_GRID_L2(int nIterations); void optimize_GRID_FIXED_MATRIX(int nIterations); void optimize_GRID_GENERAL(int nIterations); void optimize_GRID_BINARY(int nIterations); }; #endif /* __TRWS_H__ */