001 // --- BEGIN LICENSE BLOCK ---
002 /*
003 * Copyright (c) 2009, Mikio L. Braun
004 * All rights reserved.
005 *
006 * Redistribution and use in source and binary forms, with or without
007 * modification, are permitted provided that the following conditions are
008 * met:
009 *
010 * * Redistributions of source code must retain the above copyright
011 * notice, this list of conditions and the following disclaimer.
012 *
013 * * Redistributions in binary form must reproduce the above
014 * copyright notice, this list of conditions and the following
015 * disclaimer in the documentation and/or other materials provided
016 * with the distribution.
017 *
018 * * Neither the name of the Technische Universität Berlin nor the
019 * names of its contributors may be used to endorse or promote
020 * products derived from this software without specific prior
021 * written permission.
022 *
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
024 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
025 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
026 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
027 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
028 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
029 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
030 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
031 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
032 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
033 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
034 */
035 // --- END LICENSE BLOCK ---
036
037 package org.jblas;
038
039 /**
040 * <p>General functions which are geometric in nature.</p>
041 *
042 * <p>For example, computing all pairwise squared distances between all columns of a matrix.</p>
043 */
044 public class Geometry {
045
046 /**
047 * <p>Compute the pairwise squared distances between all columns of the two
048 * matrices.</p>
049 *
050 * <p>An efficient way to do this is to observe that <i>(x-y)^2 = x^2 - 2xy - y^2</i>
051 * and to then properly carry out the computation with matrices.</p>
052 */
053 public static DoubleMatrix pairwiseSquaredDistances(DoubleMatrix X, DoubleMatrix Y) {
054 if (X.rows != Y.rows)
055 throw new IllegalArgumentException(
056 "Matrices must have same number of rows");
057
058 DoubleMatrix XX = X.mul(X).columnSums();
059 DoubleMatrix YY = Y.mul(Y).columnSums();
060
061 DoubleMatrix Z = X.transpose().mmul(Y);
062 Z.muli(-2.0); //Z.print();
063 Z.addiColumnVector(XX);
064 Z.addiRowVector(YY);
065
066 return Z;
067 }
068
069 /** Center a vector (subtract mean from all elements (in-place). */
070 public static DoubleMatrix center(DoubleMatrix x) {
071 return x.subi(x.mean());
072 }
073
074 /** Center the rows of a matrix (in-place). */
075 public static DoubleMatrix centerRows(DoubleMatrix x) {
076 DoubleMatrix temp = new DoubleMatrix(x.columns);
077 for (int r = 0; r < x.rows; r++)
078 x.putRow(r, center(x.getRow(r, temp)));
079 return x;
080 }
081
082 /** Center the columns of a matrix (in-place). */
083 public static DoubleMatrix centerColumns(DoubleMatrix x) {
084 DoubleMatrix temp = new DoubleMatrix(x.rows);
085 for (int c = 0; c < x.columns; c++)
086 x.putColumn(c, center(x.getColumn(c, temp)));
087 return x;
088 }
089
090 /** Normalize a vector (scale such that its Euclidean norm is 1) (in-place). */
091 public static DoubleMatrix normalize(DoubleMatrix x) {
092 return x.divi(x.norm2());
093 }
094
095 /** Normalize the rows of a matrix (in-place). */
096 public static DoubleMatrix normalizeRows(DoubleMatrix x) {
097 DoubleMatrix temp = new DoubleMatrix(x.columns);
098 for (int r = 0; r < x.rows; r++)
099 x.putRow(r, normalize(x.getRow(r, temp)));
100 return x;
101 }
102
103 /** Normalize the columns of a matrix (in-place). */
104 public static DoubleMatrix normalizeColumns(DoubleMatrix x) {
105 DoubleMatrix temp = new DoubleMatrix(x.rows);
106 for (int c = 0; c < x.columns; c++)
107 x.putColumn(c, normalize(x.getColumn(c, temp)));
108 return x;
109 }
110
111 //BEGIN
112 // The code below has been automatically generated.
113 // DO NOT EDIT!
114
115 /**
116 * <p>Compute the pairwise squared distances between all columns of the two
117 * matrices.</p>
118 *
119 * <p>An efficient way to do this is to observe that <i>(x-y)^2 = x^2 - 2xy - y^2</i>
120 * and to then properly carry out the computation with matrices.</p>
121 */
122 public static FloatMatrix pairwiseSquaredDistances(FloatMatrix X, FloatMatrix Y) {
123 if (X.rows != Y.rows)
124 throw new IllegalArgumentException(
125 "Matrices must have same number of rows");
126
127 FloatMatrix XX = X.mul(X).columnSums();
128 FloatMatrix YY = Y.mul(Y).columnSums();
129
130 FloatMatrix Z = X.transpose().mmul(Y);
131 Z.muli(-2.0f); //Z.print();
132 Z.addiColumnVector(XX);
133 Z.addiRowVector(YY);
134
135 return Z;
136 }
137
138 /** Center a vector (subtract mean from all elements (in-place). */
139 public static FloatMatrix center(FloatMatrix x) {
140 return x.subi(x.mean());
141 }
142
143 /** Center the rows of a matrix (in-place). */
144 public static FloatMatrix centerRows(FloatMatrix x) {
145 FloatMatrix temp = new FloatMatrix(x.columns);
146 for (int r = 0; r < x.rows; r++)
147 x.putRow(r, center(x.getRow(r, temp)));
148 return x;
149 }
150
151 /** Center the columns of a matrix (in-place). */
152 public static FloatMatrix centerColumns(FloatMatrix x) {
153 FloatMatrix temp = new FloatMatrix(x.rows);
154 for (int c = 0; c < x.columns; c++)
155 x.putColumn(c, center(x.getColumn(c, temp)));
156 return x;
157 }
158
159 /** Normalize a vector (scale such that its Euclidean norm is 1) (in-place). */
160 public static FloatMatrix normalize(FloatMatrix x) {
161 return x.divi(x.norm2());
162 }
163
164 /** Normalize the rows of a matrix (in-place). */
165 public static FloatMatrix normalizeRows(FloatMatrix x) {
166 FloatMatrix temp = new FloatMatrix(x.columns);
167 for (int r = 0; r < x.rows; r++)
168 x.putRow(r, normalize(x.getRow(r, temp)));
169 return x;
170 }
171
172 /** Normalize the columns of a matrix (in-place). */
173 public static FloatMatrix normalizeColumns(FloatMatrix x) {
174 FloatMatrix temp = new FloatMatrix(x.rows);
175 for (int c = 0; c < x.columns; c++)
176 x.putColumn(c, normalize(x.getColumn(c, temp)));
177 return x;
178 }
179
180 //END
181 }