#include // defines various exceptions we throw #include // need stringstream for operator<< #include // needed for operator<< #include "matrix_expression.h" #include "matrix.h" #include "matrix_proxy.h" using namespace std; /***************************************************** * support for matrix_expression class ****************************************************/ // assignment operator supports syntax: A = B; // For generic expressions, this is only well defined if dimensions agree. matrix_expression& matrix_expression::operator=(const matrix_expression& other) { if (numRows() != other.numRows() || numColumns() != other.numColumns()) throw invalid_argument("Matrix dimensions must agree."); for (int r=0; r < numRows(); r++) for (int c=0; c < numColumns(); c++) (*this)(r,c) = other(r,c); return *this; } // Returns a 1x2 matrix describing the number of rows and columns respectively matrix matrix_expression::size() const { matrix result(1,2); result(0,0) = numRows(); result(0,1) = numColumns(); return result; } bool matrix_expression::operator==(const matrix_expression &other) const { if (numRows() != other.numRows() || numColumns() != other.numColumns()) return false; // not equivalent for (int r=0; r < numRows(); r++) for (int c=0; c < numColumns(); c++) if ((*this)(r,c) != other(r,c)) return false; // not equivalent return true; // by process of elimination, must be equivalent } bool matrix_expression::operator!=(const matrix_expression &other) const { return !(*this == other); // piggy-back on existing definition of == } // provides read-only access to a submatrix via a proxy const matrix_proxy matrix_expression::operator()(range rows, range cols) const { return matrix_proxy(*const_cast(this), rows, cols); } // provides write access to a submatrix as a proxy matrix_proxy matrix_expression::operator()(range rows, range cols) { return matrix_proxy(*this, rows, cols); } //----------------------------------------------- // addition //----------------------------------------------- // returns new matrix instance based on sum of two expressions matrix matrix_expression::operator+(const matrix_expression& other) const { if (numRows() != other.numRows() || numColumns() != other.numColumns()) throw invalid_argument("Matrix dimensions must agree."); matrix result(*this); // a new matrix instance for result for (int r=0; r < numRows(); r++) for (int c=0; c < numColumns(); c++) result(r,c) += other(r,c); return result; } // in-place addition with another matrix expression matrix_expression& matrix_expression::operator+=(const matrix_expression& other) { if (numRows() != other.numRows() || numColumns() != other.numColumns()) throw invalid_argument("Matrix dimensions must agree."); for (int r=0; r < numRows(); r++) for (int c=0; c < numColumns(); c++) (*this)(r,c) += other(r,c); return *this; } // returns new matrix instance based on element-wise addition matrix matrix_expression::operator+(double scalar) const { matrix result(*this); for (int r=0; r < numRows(); r++) for (int c=0; c < numColumns(); c++) result(r,c) += scalar; return result; } // in-place element-wise addition with a scalar matrix_expression& matrix_expression::operator+=(double scalar) { for (int r=0; r < numRows(); r++) for (int c=0; c < numColumns(); c++) (*this)(r,c) += scalar; return *this; } //----------------------------------------------- // multiplication //----------------------------------------------- // returns a new matrix based on element-wise multiplication by a scalar // e.g., C = A*B; matrix matrix_expression::operator*(double scalar) const { // multiply each element by scalar matrix result = matrix(*this); for (int r=0; r < numRows(); r++) for (int c=0; c < numColumns(); c++) result(r,c) *= scalar; return result; } // in-place element-wise multiplication with a scalar // e.g., C = A*4; matrix_expression& matrix_expression::operator*=(double scalar) { for (int r=0; r < numRows(); r++) for (int c=0; c < numColumns(); c++) (*this)(r,c) *= scalar; return *this; } matrix matrix_expression::operator*(const matrix_expression& other) const { if (numColumns() != other.numRows()) throw invalid_argument("Inner matrix dimensions must agree."); matrix result = matrix(numRows(), other.numColumns()); // all zeros initially for (int r=0; r < numRows(); r++) for (int c=0; c < other.numColumns(); c++) for(int k=0; k < numColumns(); k++) // compute appropriate dot-product result(r,c) += (*this)(r,k) * other(k,c); return result; } //----------------------------------------------- // support for outputting a matrix expression //----------------------------------------------- ostream& operator<<(ostream& out, const matrix_expression& m) { string temp; unsigned int maxfield = 0; for (int r=0; r < m.numRows(); r++) { for (int c=0; c < m.numColumns(); c++) { stringstream s; s << fixed << setprecision(3); s << m(r,c); s >> temp; if (temp.size() > maxfield) maxfield = temp.size(); } } for (int r=0; r < m.numRows(); r++) { for (int c=0; c < m.numColumns(); c++) { stringstream s; s << fixed << setprecision(3); s << m(r,c); s >> temp; out << " " << setw(maxfield) << temp; } out << endl; } return out; } //----------------------------------------------- // support for scalar arithmetic as left-hand operand //----------------------------------------------- matrix operator+(double scalar, const matrix_expression& m) { return m + scalar; // reverse order of operands to invoke class method } matrix operator*(double scalar, const matrix_expression& m) { return m * scalar; // reverse order of operands to invoke class method }